From 7b8b9372867c3d1f83ea93e87cf214b13e8a2573 Mon Sep 17 00:00:00 2001 From: Adam Grabowski Date: Wed, 28 Jun 2023 13:28:03 +0000 Subject: [PATCH 01/10] Fused sampling Co-authored-by: Hesham Mostafa --- .../api/bench_fused_sample_neighbors.py | 39 ++ include/dgl/aten/csr.h | 66 ++++ include/dgl/sampling/neighbor.h | 50 +++ python/dgl/dataloading/neighbor_sampler.py | 57 ++- python/dgl/sampling/neighbor.py | 114 +++++- src/array/array.cc | 41 +++ src/array/array_op.h | 13 + src/array/cpu/rowwise_pick.h | 110 ++++++ src/array/cpu/rowwise_sampling.cc | 95 +++++ src/graph/sampling/neighbor/neighbor.cc | 336 ++++++++++++++++++ src/graph/unit_graph.cc | 30 ++ src/graph/unit_graph.h | 12 + tests/python/common/sampling/test_sampling.py | 296 +++++++++++---- 13 files changed, 1165 insertions(+), 94 deletions(-) create mode 100644 benchmarks/benchmarks/api/bench_fused_sample_neighbors.py diff --git a/benchmarks/benchmarks/api/bench_fused_sample_neighbors.py b/benchmarks/benchmarks/api/bench_fused_sample_neighbors.py new file mode 100644 index 000000000000..91fc6705d5eb --- /dev/null +++ b/benchmarks/benchmarks/api/bench_fused_sample_neighbors.py @@ -0,0 +1,39 @@ +import time + +import dgl +import dgl.function as fn + +import numpy as np +import torch + +from .. import utils + + +@utils.benchmark("time") +@utils.parametrize_cpu("graph_name", ["livejournal", "reddit"]) +@utils.parametrize_gpu("graph_name", ["ogbn-arxiv", "reddit"]) +@utils.parametrize("format", ["csr", "csc"]) +@utils.parametrize("seed_nodes_num", [200, 5000, 20000]) +@utils.parametrize("fanout", [5, 20, 40]) +def track_time(graph_name, format, seed_nodes_num, fanout): + device = utils.get_bench_device() + graph = utils.get_graph(graph_name, format).to(device) + + edge_dir = "in" if format == "csc" else "out" + seed_nodes = np.random.randint(0, graph.num_nodes(), seed_nodes_num) + seed_nodes = torch.from_numpy(seed_nodes).to(device) + + # dry run + for i in range(3): + dgl.sampling.sample_neighbors( + graph, seed_nodes, fanout, edge_dir=edge_dir, fused=True + ) + + # timing + with utils.Timer() as t: + for i in range(50): + dgl.sampling.sample_neighbors( + graph, seed_nodes, fanout, edge_dir=edge_dir, fused=True + ) + + return t.elapsed_secs / 50 diff --git a/include/dgl/aten/csr.h b/include/dgl/aten/csr.h index cbed0e41cc5e..3ef1571359b9 100644 --- a/include/dgl/aten/csr.h +++ b/include/dgl/aten/csr.h @@ -572,6 +572,72 @@ COOMatrix CSRRowWiseSampling( CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask = NDArray(), bool replace = true); +/*! + * @brief Randomly select a fixed number of non-zero entries along each given + * row independently. + * + * The function performs random choices along each row independently. + * The picked indices are returned in the form of a CSR matrix, with + * additional IdArray that is an extended version of CSR's index pointers. + * + * With template parameter set to True rows are also saved as new seed nodes and + * mapped + * + * If replace is false and a row has fewer non-zero values than num_samples, + * all the values are picked. + * + * Examples: + * + * // csr.num_rows = 4; + * // csr.num_cols = 4; + * // csr.indptr = [0, 2, 3, 3, 5] + * // csr.indices = [0, 1, 1, 2, 3] + * // csr.data = [2, 3, 0, 1, 4] + * CSRMatrix csr = ...; + * IdArray rows = ... ; // [1, 3] + * IdArray seed_mapping = [-1, -1, -1, -1]; + * std::vector new_seed_nodes = {}; + * + * std::pair sampled = CSRRowWiseSamplingFused< + * typename IdType, True>( + * csr, rows, seed_mapping, + * new_seed_nodes, 2, + * FloatArray(), false); + * // possible sampled csr matrix: + * // sampled.first.num_rows = 2 + * // sampled.first.num_cols = 3 + * // sampled.first.indptr = [0, 1, 3] + * // sampled.first.indices = [1, 2, 3] + * // sampled.first.data = [0, 1, 4] + * // sampled.second = [0, 1, 1] + * // seed_mapping = [-1, 0, -1, 1]; + * // new_seed_nodes = {1, 3}; + * + * @tparam IdType Graph's index data type, can be int32_t or int64_t + * @tparam map_seed_nodes If set for true we map and copy rows to new_seed_nodes + * @param mat Input CSR matrix. + * @param rows Rows to sample from. + * @param seed_mapping Mapping array used if map_seed_nodes=true. If so each row + * from rows will be set to its position e.g. mapping[rows[i]] = i. + * @param new_seed_nodes Vector used if map_seed_nodes=true. If so it will + * contain rows. + * @param rows Rows to sample from. + * @param num_samples Number of samples + * @param prob_or_mask Unnormalized probability array or mask array. + * Should be of the same length as the data array. + * If an empty array is provided, assume uniform. + * @param replace True if sample with replacement + * @return A CSRMatrix storing the picked row, col and data indices, + * COO version of picked rows + * @note The edges of the entire graph must be ordered by their edge types, + * rows must be unique + */ +template +std::pair CSRRowWiseSamplingFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector& new_seed_nodes, int64_t num_samples, + NDArray prob_or_mask = NDArray(), bool replace = true); + /** * @brief Randomly select a fixed number of non-zero entries for each edge type * along each given row independently. diff --git a/include/dgl/sampling/neighbor.h b/include/dgl/sampling/neighbor.h index 7c17050777a2..375618eb77d9 100644 --- a/include/dgl/sampling/neighbor.h +++ b/include/dgl/sampling/neighbor.h @@ -9,6 +9,7 @@ #include #include +#include #include namespace dgl { @@ -47,6 +48,55 @@ HeteroSubgraph SampleNeighbors( const std::vector& probability, const std::vector& exclude_edges, bool replace = true); +/** + * @brief Sample from the neighbors of the given nodes and convert a graph into + * a bipartite-structured graph for message passing. + * + * Specifically, we create one node type \c ntype_l on the "left" side and + * another node type \c ntype_r on the "right" side for each node type \c ntype. + * The nodes of type \c ntype_r would contain the nodes designated by the + * caller, and node type \c ntype_l would contain the nodes that has an edge + * connecting to one of the designated nodes. + * + * The nodes of \c ntype_l would also contain the nodes in node type \c ntype_r. + * When sampling with replacement, the sampled subgraph could have parallel + * edges. + * + * For sampling without replace, if fanout > the number of neighbors, all the + * neighbors will be sampled. + * + * Non-deterministic algorithm, requires nodes parameter to store unique Node + * IDs. + * + * @tparam IdType Graph's index data type, can be int32_t or int64_t + * @param hg The input graph. + * @param nodes Node IDs of each type. The vector length must be equal to the + * number of node types. Empty array is allowed. + * @param mapping External parameter that should be set to a vector of IdArrays + * filled with -1, required for mapping of nodes in returned + * graph + * @param fanouts Number of sampled neighbors for each edge type. The vector + * length should be equal to the number of edge types, or one if they all have + * the same fanout. + * @param dir Edge direction. + * @param probability A vector of 1D float arrays, indicating the transition + * probability of each edge by edge type. An empty float array assumes uniform + * transition. + * @param exclude_edges Edges IDs of each type which will be excluded during + * sampling. The vector length must be equal to the number of edges types. Empty + * array is allowed. + * @param replace If true, sample with replacement. + * @return Sampled neighborhoods as a graph. The return graph has the same + * schema as the original one. + */ +template +std::tuple, std::vector> +SampleNeighborsFused( + const HeteroGraphPtr hg, const std::vector& nodes, + std::vector& mapping, const std::vector& fanouts, + EdgeDir dir, const std::vector& prob_or_mask, + const std::vector& exclude_edges, bool replace = true); + /** * Select the neighbors with k-largest weights on the connecting edges for each * given node. diff --git a/python/dgl/dataloading/neighbor_sampler.py b/python/dgl/dataloading/neighbor_sampler.py index 573946eade9f..3243ee1d4dd6 100644 --- a/python/dgl/dataloading/neighbor_sampler.py +++ b/python/dgl/dataloading/neighbor_sampler.py @@ -1,4 +1,5 @@ """Data loading components for neighbor sampling""" +from .. import backend as F from ..base import EID, NID from ..transforms import to_block from .base import BlockSampler @@ -54,6 +55,9 @@ class NeighborSampler(BlockSampler): output_device : device, optional The device of the output subgraphs or MFGs. Default is the same as the minibatch of seed nodes. + fused : bool, default True + If True and device is CPU fused sample neighbors is invoked. This version + requires seed_nodes to be unique Examples -------- @@ -120,6 +124,7 @@ def __init__( prefetch_labels=None, prefetch_edge_feats=None, output_device=None, + fused=True, ): super().__init__( prefetch_node_feats=prefetch_node_feats, @@ -137,25 +142,47 @@ def __init__( ) self.prob = prob or mask self.replace = replace + self.fused = fused + self.mapping = {} + self.g = None def sample_blocks(self, g, seed_nodes, exclude_eids=None): output_nodes = seed_nodes blocks = [] - for fanout in reversed(self.fanouts): - frontier = g.sample_neighbors( - seed_nodes, - fanout, - edge_dir=self.edge_dir, - prob=self.prob, - replace=self.replace, - output_device=self.output_device, - exclude_edges=exclude_eids, - ) - eid = frontier.edata[EID] - block = to_block(frontier, seed_nodes) - block.edata[EID] = eid - seed_nodes = block.srcdata[NID] - blocks.insert(0, block) + if F.device_type(g.device) == "cpu" and self.fused: + if self.g != g: + self.mapping = {} + self.g = g + for fanout in reversed(self.fanouts): + block = g.sample_neighbors( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + fused=True, + exclude_edges=exclude_eids, + mapping=self.mapping, + ) + seed_nodes = block.srcdata[NID] + blocks.insert(0, block) + else: + for fanout in reversed(self.fanouts): + frontier = g.sample_neighbors( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=exclude_eids, + ) + eid = frontier.edata[EID] + block = to_block(frontier, seed_nodes) + block.edata[EID] = eid + seed_nodes = block.srcdata[NID] + blocks.insert(0, block) return seed_nodes, output_nodes, blocks diff --git a/python/dgl/sampling/neighbor.py b/python/dgl/sampling/neighbor.py index 232fb3972744..d96372ca39f6 100644 --- a/python/dgl/sampling/neighbor.py +++ b/python/dgl/sampling/neighbor.py @@ -1,9 +1,13 @@ """Neighbor sampling APIs""" +import os + +import torch + from .. import backend as F, ndarray as nd, utils from .._ffi.function import _init_api -from ..base import DGLError, EID -from ..heterograph import DGLGraph +from ..base import DGLError, EID, NID +from ..heterograph import DGLBlock, DGLGraph from .utils import EidExcluder __all__ = [ @@ -214,6 +218,8 @@ def sample_neighbors( _dist_training=False, exclude_edges=None, output_device=None, + fused=False, + mapping=None, ): """Sample neighboring edges of the given nodes and return the induced subgraph. @@ -282,6 +288,18 @@ def sample_neighbors( output_device : Framework-specific device context object, optional The output device. Default is the same as the input graph. + fused : bool, optional + Enables faster version of NeighborSampler that is also compacting output graph, + returning a computational block. Requires nodes to be unique + + (Default: False) + + mapping : dictionary, optional + Used by fused version of sample_neighbors. To avoid constant data allocation + provide empty dictionary ({}) that will be allocated once with proper data and reused + by each function call + + (Default: None) Returns ------- DGLGraph @@ -361,6 +379,8 @@ def sample_neighbors( copy_ndata=copy_ndata, copy_edata=copy_edata, exclude_edges=exclude_edges, + fused=fused, + mapping=mapping, ) else: frontier = _sample_neighbors( @@ -372,6 +392,8 @@ def sample_neighbors( replace=replace, copy_ndata=copy_ndata, copy_edata=copy_edata, + fused=fused, + mapping=mapping, ) if exclude_edges is not None: eid_excluder = EidExcluder(exclude_edges) @@ -390,6 +412,8 @@ def _sample_neighbors( copy_edata=True, _dist_training=False, exclude_edges=None, + fused=False, + mapping=None, ): if not isinstance(nodes, dict): if len(g.ntypes) > 1: @@ -446,17 +470,54 @@ def _sample_neighbors( else: excluded_edges_all_t.append(nd.array([], ctx=ctx)) - subgidx = _CAPI_DGLSampleNeighbors( - g._graph, - nodes_all_types, - fanout_array, - edge_dir, - prob_arrays, - excluded_edges_all_t, - replace, - ) - induced_edges = subgidx.induced_edges - ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes) + if fused: + if _dist_training: + raise DGLError( + "distributed training not supported in fused sampling" + ) + if F.device_type(g.device) != "cpu": + raise DGLError("Only cpu is supported in fused sampling") + + if mapping == None: + mapping = {} + mapping_name = "__mapping" + str(os.getpid()) + if mapping_name not in mapping.keys(): + mapping[mapping_name] = [ + torch.LongTensor(g.num_nodes(ntype)).fill_(-1) + for ntype in g.ntypes + ] + + subgidx, induced_nodes, induced_edges = _CAPI_DGLSampleNeighborsFused( + g._graph, + nodes_all_types, + [F.to_dgl_nd(m) for m in mapping[mapping_name]], + fanout_array, + edge_dir, + prob_arrays, + excluded_edges_all_t, + replace, + ) + for mapping_vector, src_nodes in zip( + mapping[mapping_name], induced_nodes + ): + mapping_vector[F.from_dgl_nd(src_nodes).type(torch.int64)] = -1 + + new_ntypes = (g.ntypes, g.ntypes) + ret = DGLBlock(subgidx, new_ntypes, g.etypes) + assert ret.is_unibipartite + + else: + subgidx = _CAPI_DGLSampleNeighbors( + g._graph, + nodes_all_types, + fanout_array, + edge_dir, + prob_arrays, + excluded_edges_all_t, + replace, + ) + ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes) + induced_edges = subgidx.induced_edges # handle features # (TODO) (BarclayII) DGL distributed fails with bus error, freezes, or other @@ -465,12 +526,31 @@ def _sample_neighbors( # only set the edge IDs. if not _dist_training: if copy_ndata: - node_frames = utils.extract_node_subframes(g, device) - utils.set_new_frames(ret, node_frames=node_frames) + if fused: + src_node_ids = [F.from_dgl_nd(src) for src in induced_nodes] + dst_node_ids = [ + utils.toindex( + nodes.get(ntype, []), g._idtype_str + ).tousertensor(ctx=F.to_backend_ctx(g._graph.ctx)) + for ntype in g.ntypes + ] + node_frames = utils.extract_node_subframes_for_block( + g, src_node_ids, dst_node_ids + ) + utils.set_new_frames(ret, node_frames=node_frames) + else: + node_frames = utils.extract_node_subframes(g, device) + utils.set_new_frames(ret, node_frames=node_frames) if copy_edata: - edge_frames = utils.extract_edge_subframes(g, induced_edges) - utils.set_new_frames(ret, edge_frames=edge_frames) + if fused: + edge_ids = [F.from_dgl_nd(eid) for eid in induced_edges] + edge_frames = utils.extract_edge_subframes(g, edge_ids) + utils.set_new_frames(ret, edge_frames=edge_frames) + else: + edge_frames = utils.extract_edge_subframes(g, induced_edges) + utils.set_new_frames(ret, edge_frames=edge_frames) + else: for i, etype in enumerate(ret.canonical_etypes): ret.edges[etype].data[EID] = induced_edges[i] diff --git a/src/array/array.cc b/src/array/array.cc index 57a2af0761c8..d4cf3693e704 100644 --- a/src/array/array.cc +++ b/src/array/array.cc @@ -597,6 +597,47 @@ COOMatrix CSRRowWiseSampling( return ret; } +template +std::pair CSRRowWiseSamplingFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector& new_seed_nodes, int64_t num_samples, + NDArray prob_or_mask, bool replace) { + std::pair ret; + if (IsNullArray(prob_or_mask)) { + ATEN_XPU_SWITCH( + rows->ctx.device_type, XPU, "CSRRowWiseSamplingUniformFused", { + ret = + impl::CSRRowWiseSamplingUniformFused( + mat, rows, seed_mapping, new_seed_nodes, num_samples, + replace); + }); + } else { + CHECK_VALID_CONTEXT(prob_or_mask, rows); + ATEN_XPU_SWITCH(rows->ctx.device_type, XPU, "CSRRowWiseSamplingFused", { + ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH( + prob_or_mask->dtype, FloatType, "probability or mask", { + ret = impl::CSRRowWiseSamplingFused< + XPU, IdType, FloatType, map_seed_nodes>( + mat, rows, seed_mapping, new_seed_nodes, num_samples, + prob_or_mask, replace); + }); + }); + } + return ret; +} + +template std::pair CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + +template std::pair CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + +template std::pair CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + +template std::pair CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + COOMatrix CSRRowWisePerEtypeSampling( CSRMatrix mat, IdArray rows, const std::vector& eid2etype_offset, const std::vector& num_samples, diff --git a/src/array/array_op.h b/src/array/array_op.h index 3b37db78fba2..91e1d2f1a56a 100644 --- a/src/array/array_op.h +++ b/src/array/array_op.h @@ -178,6 +178,14 @@ COOMatrix CSRRowWiseSampling( CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace); +// FloatType is the type of probability data. +template < + DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes> +std::pair CSRRowWiseSamplingFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector& new_seed_nodes, int64_t num_samples, + NDArray prob_or_mask, bool replace); + // FloatType is the type of probability data. template COOMatrix CSRRowWisePerEtypeSampling( @@ -190,6 +198,11 @@ template COOMatrix CSRRowWiseSamplingUniform( CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace); +template +std::pair CSRRowWiseSamplingUniformFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector& new_seed_nodes, int64_t num_samples, bool replace); + template COOMatrix CSRRowWisePerEtypeSamplingUniform( CSRMatrix mat, IdArray rows, const std::vector& eid2etype_offset, diff --git a/src/array/cpu/rowwise_pick.h b/src/array/cpu/rowwise_pick.h index cb0b32298157..2edcacba619e 100644 --- a/src/array/cpu/rowwise_pick.h +++ b/src/array/cpu/rowwise_pick.h @@ -14,6 +14,7 @@ #include #include #include +#include #include namespace dgl { @@ -94,6 +95,115 @@ using EtypeRangePickFn = std::function& et_idx, const std::vector& et_eid, const IdxType* eid, IdxType* out_idx)>; +template +std::pair CSRRowWisePickFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector& new_seed_nodes, int64_t num_picks, bool replace, + PickFn pick_fn, NumPicksFn num_picks_fn) { + using namespace aten; + + const IdxType* indptr = static_cast(mat.indptr->data); + const IdxType* indices = static_cast(mat.indices->data); + const IdxType* data = + CSRHasData(mat) ? static_cast(mat.data->data) : nullptr; + const IdxType* rows_data = static_cast(rows->data); + const int64_t num_rows = rows->shape[0]; + const auto& ctx = mat.indptr->ctx; + const auto& idtype = mat.indptr->dtype; + IdxType* seed_mapping_data = nullptr; + if (map_seed_nodes) seed_mapping_data = seed_mapping.Ptr(); + + const int num_threads = runtime::compute_num_threads(0, num_rows, 1); + std::vector global_prefix(num_threads + 1, 0); + + IdArray picked_col, picked_idx, picked_coo_rows; + + IdArray block_csr_indptr = IdArray::Empty({num_rows + 1}, idtype, ctx); + IdxType* block_csr_indptr_data = block_csr_indptr.Ptr(); + +#pragma omp parallel num_threads(num_threads) + { + const int thread_id = omp_get_thread_num(); + + const int64_t start_i = + thread_id * (num_rows / num_threads) + + std::min(static_cast(thread_id), num_rows % num_threads); + const int64_t end_i = + (thread_id + 1) * (num_rows / num_threads) + + std::min(static_cast(thread_id + 1), num_rows % num_threads); + assert(thread_id + 1 < num_threads || end_i == num_rows); + + const int64_t num_local = end_i - start_i; + + std::unique_ptr local_prefix(new int64_t[num_local + 1]); + local_prefix[0] = 0; + for (int64_t i = start_i; i < end_i; ++i) { + // build prefix-sum + const int64_t local_i = i - start_i; + const IdxType rid = rows_data[i]; + if (map_seed_nodes) seed_mapping_data[rid] = i; + + IdxType len = num_picks_fn( + rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data); + local_prefix[local_i + 1] = local_prefix[local_i] + len; + } + global_prefix[thread_id + 1] = local_prefix[num_local]; + +#pragma omp barrier +#pragma omp master + { + for (int t = 0; t < num_threads; ++t) { + global_prefix[t + 1] += global_prefix[t]; + } + picked_col = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx); + picked_idx = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx); + picked_coo_rows = + IdArray::Empty({global_prefix[num_threads]}, idtype, ctx); + } + +#pragma omp barrier + IdxType* picked_cdata = picked_col.Ptr(); + IdxType* picked_idata = picked_idx.Ptr(); + IdxType* picked_rows = picked_coo_rows.Ptr(); + + const IdxType thread_offset = global_prefix[thread_id]; + + for (int64_t i = start_i; i < end_i; ++i) { + const IdxType rid = rows_data[i]; + const int64_t local_i = i - start_i; + block_csr_indptr_data[i] = local_prefix[local_i] + thread_offset; + + const IdxType off = indptr[rid]; + const IdxType len = indptr[rid + 1] - off; + if (len == 0) continue; + + const int64_t row_offset = local_prefix[local_i] + thread_offset; + const int64_t num_picks = + local_prefix[local_i + 1] + thread_offset - row_offset; + + pick_fn( + rid, off, len, num_picks, indices, data, picked_idata + row_offset); + for (int64_t j = 0; j < num_picks; ++j) { + const IdxType picked = picked_idata[row_offset + j]; + picked_cdata[row_offset + j] = indices[picked]; + picked_idata[row_offset + j] = data ? data[picked] : picked; + picked_rows[row_offset + j] = i; + } + } + } + block_csr_indptr_data[num_rows] = global_prefix.back(); + + const IdxType num_cols = picked_col->shape[0]; + if (map_seed_nodes) { + new_seed_nodes.resize(num_rows); + memcpy(new_seed_nodes.data(), rows_data, sizeof(IdxType) * num_rows); + } + + return std::make_pair( + CSRMatrix(num_rows, num_cols, block_csr_indptr, picked_col, picked_idx), + picked_coo_rows); +} + // Template for picking non-zero values row-wise. The implementation utilizes // OpenMP parallelization on rows because each row performs computation // independently. diff --git a/src/array/cpu/rowwise_sampling.cc b/src/array/cpu/rowwise_sampling.cc index 56bb9a40bb4c..911f69b42371 100644 --- a/src/array/cpu/rowwise_sampling.cc +++ b/src/array/cpu/rowwise_sampling.cc @@ -225,6 +225,74 @@ template COOMatrix CSRRowWiseSampling( template COOMatrix CSRRowWiseSampling( CSRMatrix, IdArray, int64_t, NDArray, bool); +template < + DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes> +std::pair CSRRowWiseSamplingFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector& new_seed_nodes, int64_t num_samples, + NDArray prob_or_mask, bool replace) { + // If num_samples is -1, select all neighbors without replacement. + replace = (replace && num_samples != -1); + CHECK(prob_or_mask.defined()); + auto num_picks_fn = + GetSamplingNumPicksFn(num_samples, prob_or_mask, replace); + auto pick_fn = + GetSamplingPickFn(num_samples, prob_or_mask, replace); + return CSRRowWisePickFused( + mat, rows, seed_mapping, new_seed_nodes, num_samples, replace, pick_fn, + num_picks_fn); +} + +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + template COOMatrix CSRRowWisePerEtypeSampling( CSRMatrix mat, IdArray rows, const std::vector& eid2etype_offset, @@ -283,6 +351,33 @@ template COOMatrix CSRRowWiseSamplingUniform( template COOMatrix CSRRowWiseSamplingUniform( CSRMatrix, IdArray, int64_t, bool); +template +std::pair CSRRowWiseSamplingUniformFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector& new_seed_nodes, int64_t num_samples, bool replace) { + // If num_samples is -1, select all neighbors without replacement. + replace = (replace && num_samples != -1); + auto num_picks_fn = + GetSamplingUniformNumPicksFn(num_samples, replace); + auto pick_fn = GetSamplingUniformPickFn(num_samples, replace); + return CSRRowWisePickFused( + mat, rows, seed_mapping, new_seed_nodes, num_samples, replace, pick_fn, + num_picks_fn); +} + +template std::pair +CSRRowWiseSamplingUniformFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, bool); +template std::pair +CSRRowWiseSamplingUniformFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, bool); +template std::pair +CSRRowWiseSamplingUniformFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, bool); +template std::pair +CSRRowWiseSamplingUniformFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, bool); + template COOMatrix CSRRowWisePerEtypeSamplingUniform( CSRMatrix mat, IdArray rows, const std::vector& eid2etype_offset, diff --git a/src/graph/sampling/neighbor/neighbor.cc b/src/graph/sampling/neighbor/neighbor.cc index 5a161f744534..ae9da862eb98 100644 --- a/src/graph/sampling/neighbor/neighbor.cc +++ b/src/graph/sampling/neighbor/neighbor.cc @@ -6,8 +6,10 @@ #include #include +#include #include #include +#include #include #include @@ -22,6 +24,76 @@ using namespace dgl::aten; namespace dgl { namespace sampling { +template +void ExcludeCertainEdgesFused( + std::vector& sampled_graphs, std::vector& induced_edges, + std::vector& sampled_coo_rows, + const std::vector& exclude_edges, + std::vector* weights = nullptr) { + int etypes = sampled_graphs.size(); + std::vector remain_induced_edges(etypes); + std::vector remain_indptrs(etypes); + std::vector remain_indices(etypes); + std::vector remain_coo_rows(etypes); + std::vector remain_weights(etypes); + for (int etype = 0; etype < etypes; ++etype) { + if (exclude_edges[etype].GetSize() == 0 || + sampled_graphs[etype].num_rows == 0) { + remain_induced_edges[etype] = induced_edges[etype]; + if (weights) remain_weights[etype] = (*weights)[etype]; + continue; + } + const auto dtype = weights && (*weights)[etype]->shape[0] + ? (*weights)[etype]->dtype + : DGLDataType{kDGLFloat, 8 * sizeof(float), 1}; + ATEN_FLOAT_TYPE_SWITCH(dtype, FloatType, "weights", { + IdType* indptr = sampled_graphs[etype].indptr.Ptr(); + IdType* indices = sampled_graphs[etype].indices.Ptr(); + IdType* coo_rows = sampled_coo_rows[etype].Ptr(); + IdType* induced_edges_data = induced_edges[etype].Ptr(); + FloatType* weights_data = weights && (*weights)[etype]->shape[0] + ? (*weights)[etype].Ptr() + : nullptr; + const IdType exclude_edges_len = exclude_edges[etype]->shape[0]; + std::sort( + exclude_edges[etype].Ptr(), + exclude_edges[etype].Ptr() + exclude_edges_len); + const IdType* exclude_edges_data = exclude_edges[etype].Ptr(); + IdType outIndices = 0; + for (IdType row = 0; row < sampled_graphs[etype].indptr->shape[0] - 1; + ++row) { + auto tmp_row = indptr[row]; + if (outIndices != indptr[row]) indptr[row] = outIndices; + for (IdType col = tmp_row; col < indptr[row + 1]; ++col) { + if (!std::binary_search( + exclude_edges_data, exclude_edges_data + exclude_edges_len, + induced_edges_data[col])) { + indices[outIndices] = indices[col]; + induced_edges_data[outIndices] = induced_edges_data[col]; + coo_rows[outIndices] = coo_rows[col]; + if (weights_data) weights_data[outIndices] = weights_data[col]; + ++outIndices; + } + } + } + indptr[sampled_graphs[etype].indptr->shape[0] - 1] = outIndices; + remain_induced_edges[etype] = + aten::IndexSelect(induced_edges[etype], 0, outIndices); + remain_weights[etype] = + weights_data ? aten::IndexSelect((*weights)[etype], 0, outIndices) + : NullArray(); + remain_indices[etype] = + aten::IndexSelect(sampled_graphs[etype].indices, 0, outIndices); + sampled_coo_rows[etype] = + aten::IndexSelect(sampled_coo_rows[etype], 0, outIndices); + sampled_graphs[etype] = CSRMatrix( + sampled_graphs[etype].num_rows, outIndices, + sampled_graphs[etype].indptr, remain_indices[etype], + remain_induced_edges[etype]); + }); + } +} + std::pair> ExcludeCertainEdges( const HeteroSubgraph& sg, const std::vector& exclude_edges, const std::vector* weights = nullptr) { @@ -266,6 +338,229 @@ HeteroSubgraph SampleNeighbors( return ret; } +template +std::tuple, std::vector> +SampleNeighborsFused( + const HeteroGraphPtr hg, const std::vector& nodes, + std::vector& mapping, const std::vector& fanouts, + EdgeDir dir, const std::vector& prob_or_mask, + const std::vector& exclude_edges, bool replace) { + CHECK_EQ(nodes.size(), hg->NumVertexTypes()) + << "Number of node ID tensors must match the number of node types."; + CHECK_EQ(fanouts.size(), hg->NumEdgeTypes()) + << "Number of fanout values must match the number of edge types."; + CHECK_EQ(prob_or_mask.size(), hg->NumEdgeTypes()) + << "Number of probability tensors must match the number of edge types."; + + DGLContext ctx = aten::GetContextOf(nodes); + + std::vector sampled_graphs; + std::vector sampled_coo_rows; + std::vector induced_edges; + std::vector induced_vertices; + std::vector num_nodes_per_type; + std::vector> new_nodes_vec(hg->NumVertexTypes()); + std::vector seed_nodes_mapped(hg->NumVertexTypes(), 0); + + for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) { + auto pair = hg->meta_graph()->FindEdge(etype); + const dgl_type_t src_vtype = pair.first; + const dgl_type_t dst_vtype = pair.second; + const dgl_type_t rhs_node_type = + (dir == EdgeDir::kOut) ? src_vtype : dst_vtype; + const IdArray nodes_ntype = nodes[rhs_node_type]; + const int64_t num_nodes = nodes_ntype->shape[0]; + + if (num_nodes == 0 || fanouts[etype] == 0) { + // Nothing to sample for this etype, create a placeholder + sampled_graphs.push_back(CSRMatrix()); + sampled_coo_rows.push_back(IdArray()); + induced_edges.push_back(aten::NullArray(hg->DataType(), ctx)); + } else { + bool map_seed_nodes = !seed_nodes_mapped[rhs_node_type]; + // sample from one relation graph + std::pair sampled_graph; + auto sampling_fn = map_seed_nodes + ? aten::CSRRowWiseSamplingFused + : aten::CSRRowWiseSamplingFused; + auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE; + auto avail_fmt = hg->SelectFormat(etype, req_fmt); + switch (avail_fmt) { + case SparseFormat::kCSR: + CHECK(dir == EdgeDir::kOut) + << "Cannot sample out edges on CSC matrix."; + // In heterographs nodes of two diffrent types can be connected + // therefore two diffrent mappings and node vectors are needed + sampled_graph = sampling_fn( + hg->GetCSRMatrix(etype), nodes_ntype, mapping[src_vtype], + new_nodes_vec[src_vtype], fanouts[etype], prob_or_mask[etype], + replace); + break; + case SparseFormat::kCSC: + CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix."; + sampled_graph = sampling_fn( + hg->GetCSCMatrix(etype), nodes_ntype, mapping[dst_vtype], + new_nodes_vec[dst_vtype], fanouts[etype], prob_or_mask[etype], + replace); + break; + default: + LOG(FATAL) << "Unsupported sparse format."; + } + seed_nodes_mapped[rhs_node_type]++; + sampled_graphs.push_back(sampled_graph.first); + if (sampled_graph.first.data.defined()) + induced_edges.push_back(sampled_graph.first.data); + else + induced_edges.push_back( + aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx)); + sampled_coo_rows.push_back(sampled_graph.second); + } + } + + if (!exclude_edges.empty()) { + ExcludeCertainEdgesFused( + sampled_graphs, induced_edges, sampled_coo_rows, exclude_edges); + for (size_t i = 0; i < hg->NumEdgeTypes(); i++) { + if (sampled_graphs[i].data.defined()) + induced_edges[i] = std::move(sampled_graphs[i].data); + else + induced_edges[i] = + aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx); + } + } + + // map indices + for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) { + auto pair = hg->meta_graph()->FindEdge(etype); + const dgl_type_t src_vtype = pair.first; + const dgl_type_t dst_vtype = pair.second; + const dgl_type_t lhs_node_type = + (dir == EdgeDir::kIn) ? src_vtype : dst_vtype; + if (sampled_graphs[etype].num_cols != 0) { + auto num_cols = sampled_graphs[etype].num_cols; + const int num_threads_col = runtime::compute_num_threads(0, num_cols, 1); + std::vector global_prefix_col(num_threads_col + 1, 0); + std::vector> src_nodes_local(num_threads_col); + IdType* mapping_data_dst = mapping[lhs_node_type].Ptr(); + IdType* cdata = sampled_graphs[etype].indices.Ptr(); +#pragma omp parallel num_threads(num_threads_col) + { + const int thread_id = omp_get_thread_num(); + + const int64_t start_i = + thread_id * (num_cols / num_threads_col) + + std::min( + static_cast(thread_id), num_cols % num_threads_col); + const int64_t end_i = (thread_id + 1) * (num_cols / num_threads_col) + + std::min( + static_cast(thread_id + 1), + num_cols % num_threads_col); + assert(thread_id + 1 < num_threads_col || end_i == num_cols); + for (int64_t i = start_i; i < end_i; ++i) { + int64_t picked_idx = cdata[i]; + bool spot_claimed = __sync_bool_compare_and_swap( + &mapping_data_dst[picked_idx], -1, 0); + if (spot_claimed) src_nodes_local[thread_id].push_back(picked_idx); + } + global_prefix_col[thread_id + 1] = src_nodes_local[thread_id].size(); + +#pragma omp barrier +#pragma omp master + { + global_prefix_col[0] = new_nodes_vec[lhs_node_type].size(); + for (int t = 0; t < num_threads_col; ++t) { + global_prefix_col[t + 1] += global_prefix_col[t]; + } + } + +#pragma omp barrier + int64_t mapping_shift = global_prefix_col[thread_id]; + for (size_t i = 0; i < src_nodes_local[thread_id].size(); ++i) + mapping_data_dst[src_nodes_local[thread_id][i]] = mapping_shift + i; + +#pragma omp barrier + for (int64_t i = start_i; i < end_i; ++i) { + IdType picked_idx = cdata[i]; + IdType mapped_idx = mapping_data_dst[picked_idx]; + cdata[i] = mapped_idx; + } + } + IdType offset = new_nodes_vec[lhs_node_type].size(); + new_nodes_vec[lhs_node_type].resize(global_prefix_col.back()); + for (int thread_id = 0; thread_id < num_threads_col; ++thread_id) { + memcpy( + new_nodes_vec[lhs_node_type].data() + offset, + &src_nodes_local[thread_id][0], + src_nodes_local[thread_id].size() * sizeof(IdType)); + offset += src_nodes_local[thread_id].size(); + } + } + } + + // counting how many nodes of each ntype were sampled + num_nodes_per_type.resize(2 * hg->NumVertexTypes()); + for (size_t i = 0; i < hg->NumVertexTypes(); i++) { + num_nodes_per_type[i] = new_nodes_vec[i].size(); + num_nodes_per_type[hg->NumVertexTypes() + i] = nodes[i]->shape[0]; + induced_vertices.push_back( + VecToIdArray(new_nodes_vec[i], sizeof(IdType) * 8)); + } + + std::vector subrels(hg->NumEdgeTypes()); + for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) { + auto pair = hg->meta_graph()->FindEdge(etype); + const dgl_type_t src_vtype = pair.first; + const dgl_type_t dst_vtype = pair.second; + if (sampled_graphs[etype].num_rows == 0) { + subrels[etype] = UnitGraph::Empty( + 2, new_nodes_vec[src_vtype].size(), nodes[dst_vtype]->shape[0], + hg->DataType(), ctx); + } else { + CSRMatrix graph = sampled_graphs[etype]; + if (dir == EdgeDir::kOut) { + subrels[etype] = UnitGraph::CreateFromCSRAndCOO( + 2, + CSRMatrix( + nodes[src_vtype]->shape[0], new_nodes_vec[dst_vtype].size(), + graph.indptr, graph.indices, + Range( + 0, graph.indices->shape[0], graph.indices->dtype.bits, + ctx)), + COOMatrix( + nodes[src_vtype]->shape[0], new_nodes_vec[dst_vtype].size(), + sampled_coo_rows[etype], graph.indices), + ALL_CODE); + } else { + subrels[etype] = UnitGraph::CreateFromCSCAndCOO( + 2, + CSRMatrix( + nodes[dst_vtype]->shape[0], new_nodes_vec[src_vtype].size(), + graph.indptr, graph.indices, + Range( + 0, graph.indices->shape[0], graph.indices->dtype.bits, + ctx)), + COOMatrix( + new_nodes_vec[src_vtype].size(), nodes[dst_vtype]->shape[0], + graph.indices, sampled_coo_rows[etype]), + ALL_CODE); + } + } + } + + HeteroSubgraph ret; + + const auto meta_graph = hg->meta_graph(); + const EdgeArray etypes = meta_graph->Edges("eid"); + const IdArray new_dst = Add(etypes.dst, hg->NumVertexTypes()); + + const auto new_meta_graph = ImmutableGraph::CreateFromCOO( + hg->NumVertexTypes() * 2, etypes.src, new_dst); + + HeteroGraphPtr new_graph = + CreateHeteroGraph(new_meta_graph, subrels, num_nodes_per_type); + return std::make_tuple(new_graph, induced_edges, induced_vertices); +} + HeteroSubgraph SampleNeighborsEType( const HeteroGraphPtr hg, const IdArray nodes, const std::vector& eid2etype_offset, @@ -568,6 +863,47 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors") *rv = HeteroSubgraphRef(subg); }); +DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsFused") + .set_body([](DGLArgs args, DGLRetValue* rv) { + HeteroGraphRef hg = args[0]; + const auto& nodes = ListValueToVector(args[1]); + auto mapping = ListValueToVector(args[2]); + IdArray fanouts_array = args[3]; + const auto& fanouts = fanouts_array.ToVector(); + const std::string dir_str = args[4]; + const auto& prob_or_mask = ListValueToVector(args[5]); + const auto& exclude_edges = ListValueToVector(args[6]); + const bool replace = args[7]; + + CHECK(dir_str == "in" || dir_str == "out") + << "Invalid edge direction. Must be \"in\" or \"out\"."; + EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut; + + HeteroGraphPtr new_graph; + std::vector induced_edges; + std::vector induced_vertices; + + ATEN_ID_TYPE_SWITCH(hg->DataType(), IdType, { + std::tie(new_graph, induced_edges, induced_vertices) = + SampleNeighborsFused( + hg.sptr(), nodes, mapping, fanouts, dir, prob_or_mask, + exclude_edges, replace); + }); + + List lhs_nodes_ref; + for (IdArray& array : induced_vertices) + lhs_nodes_ref.push_back(Value(MakeValue(array))); + List induced_edges_ref; + for (IdArray& array : induced_edges) + induced_edges_ref.push_back(Value(MakeValue(array))); + List ret; + ret.push_back(HeteroGraphRef(new_graph)); + ret.push_back(lhs_nodes_ref); + ret.push_back(induced_edges_ref); + + *rv = ret; + }); + DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk") .set_body([](DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; diff --git a/src/graph/unit_graph.cc b/src/graph/unit_graph.cc index 3c26cf92f82e..b4e29de3a39a 100644 --- a/src/graph/unit_graph.cc +++ b/src/graph/unit_graph.cc @@ -1218,6 +1218,21 @@ HeteroGraphPtr UnitGraph::CreateFromCSR( return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, formats)); } +HeteroGraphPtr UnitGraph::CreateFromCSRAndCOO( + int64_t num_vtypes, const aten::CSRMatrix& csr, const aten::COOMatrix& coo, + dgl_format_code_t formats) { + CHECK(num_vtypes == 1 || num_vtypes == 2); + CHECK_EQ(coo.num_rows, csr.num_rows); + CHECK_EQ(coo.num_cols, csr.num_cols); + if (num_vtypes == 1) { + CHECK_EQ(csr.num_rows, csr.num_cols); + } + auto mg = CreateUnitGraphMetaGraph(num_vtypes); + CSRPtr csrPtr(new CSR(mg, csr)); + COOPtr cooPtr(new COO(mg, coo)); + return HeteroGraphPtr(new UnitGraph(mg, nullptr, csrPtr, cooPtr, formats)); +} + HeteroGraphPtr UnitGraph::CreateFromCSC( int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) { @@ -1237,6 +1252,21 @@ HeteroGraphPtr UnitGraph::CreateFromCSC( return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats)); } +HeteroGraphPtr UnitGraph::CreateFromCSCAndCOO( + int64_t num_vtypes, const aten::CSRMatrix& csc, const aten::COOMatrix& coo, + dgl_format_code_t formats) { + CHECK(num_vtypes == 1 || num_vtypes == 2); + CHECK_EQ(coo.num_rows, csc.num_cols); + CHECK_EQ(coo.num_cols, csc.num_rows); + if (num_vtypes == 1) { + CHECK_EQ(csc.num_rows, csc.num_cols); + } + auto mg = CreateUnitGraphMetaGraph(num_vtypes); + CSRPtr cscPtr(new CSR(mg, csc)); + COOPtr cooPtr(new COO(mg, coo)); + return HeteroGraphPtr(new UnitGraph(mg, cscPtr, nullptr, cooPtr, formats)); +} + HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { if (g->NumBits() == bits) { return g; diff --git a/src/graph/unit_graph.h b/src/graph/unit_graph.h index 5951c5001f5c..244f922e906d 100644 --- a/src/graph/unit_graph.h +++ b/src/graph/unit_graph.h @@ -190,6 +190,12 @@ class UnitGraph : public BaseHeteroGraph { int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats = ALL_CODE); + /** @brief Create a graph from (out) CSR and COO arrays, both representing the + * same graph */ + static HeteroGraphPtr CreateFromCSRAndCOO( + int64_t num_vtypes, const aten::CSRMatrix& csr, + const aten::COOMatrix& coo, dgl_format_code_t formats = ALL_CODE); + /** @brief Create a graph from (in) CSC arrays */ static HeteroGraphPtr CreateFromCSC( int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr, @@ -199,6 +205,12 @@ class UnitGraph : public BaseHeteroGraph { int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats = ALL_CODE); + /** @brief Create a graph from (in) CSC and COO arrays, both representing the + * same graph */ + static HeteroGraphPtr CreateFromCSCAndCOO( + int64_t num_vtypes, const aten::CSRMatrix& csc, + const aten::COOMatrix& coo, dgl_format_code_t formats = ALL_CODE); + /** @brief Convert the graph to use the given number of bits for storage */ static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); diff --git a/tests/python/common/sampling/test_sampling.py b/tests/python/common/sampling/test_sampling.py index f62ba858afc6..650e502f57e3 100644 --- a/tests/python/common/sampling/test_sampling.py +++ b/tests/python/common/sampling/test_sampling.py @@ -555,15 +555,18 @@ def _gen_neighbor_topk_test_graph(hypersparse, reverse): return g, hg -def _test_sample_neighbors(hypersparse, prob): +def _test_sample_neighbors(hypersparse, prob, fused): g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False) def _test1(p, replace): subg = dgl.sampling.sample_neighbors( - g, [0, 1], -1, prob=p, replace=replace + g, [0, 1], -1, prob=p, replace=replace, fused=fused ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() u, v = subg.edges() + if fused: + u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v] u_ans, v_ans, e_ans = g.in_edges([0, 1], form="all") if p is not None: emask = F.gather_row(g.edata[p], e_ans) @@ -577,11 +580,16 @@ def _test1(p, replace): for i in range(10): subg = dgl.sampling.sample_neighbors( - g, [0, 1], 2, prob=p, replace=replace + g, [0, 1], 2, prob=p, replace=replace, fused=fused ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() + assert subg.num_edges() == 4 u, v = subg.edges() + if fused: + u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v] + assert set(F.asnumpy(F.unique(v))) == {0, 1} assert F.array_equal( F.astype(g.has_edges_between(u, v), F.int64), @@ -601,10 +609,13 @@ def _test1(p, replace): def _test2(p, replace): # fanout > #neighbors subg = dgl.sampling.sample_neighbors( - g, [0, 2], -1, prob=p, replace=replace + g, [0, 2], -1, prob=p, replace=replace, fused=fused ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() u, v = subg.edges() + if fused: + u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v] u_ans, v_ans, e_ans = g.in_edges([0, 2], form="all") if p is not None: emask = F.gather_row(g.edata[p], e_ans) @@ -618,12 +629,15 @@ def _test2(p, replace): # fanout > #neighbors for i in range(10): subg = dgl.sampling.sample_neighbors( - g, [0, 2], 2, prob=p, replace=replace + g, [0, 2], 2, prob=p, replace=replace, fused=fused ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() num_edges = 4 if replace else 3 assert subg.num_edges() == num_edges u, v = subg.edges() + if fused: + u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v] assert set(F.asnumpy(F.unique(v))) == {0, 2} assert F.array_equal( F.astype(g.has_edges_between(u, v), F.int64), @@ -642,9 +656,17 @@ def _test2(p, replace): # fanout > #neighbors def _test3(p, replace): subg = dgl.sampling.sample_neighbors( - hg, {"user": [0, 1], "game": 0}, -1, prob=p, replace=replace + hg, + {"user": [0, 1], "game": 0}, + -1, + prob=p, + replace=replace, + fused=fused, ) - assert len(subg.ntypes) == 3 + if not fused: + assert len(subg.ntypes) == 3 + assert len(subg.srctypes) == 3 + assert len(subg.dsttypes) == 3 assert len(subg.etypes) == 4 assert subg["follow"].num_edges() == 6 if p is None else 4 assert subg["play"].num_edges() == 1 @@ -653,9 +675,17 @@ def _test3(p, replace): for i in range(10): subg = dgl.sampling.sample_neighbors( - hg, {"user": [0, 1], "game": 0}, 2, prob=p, replace=replace + hg, + {"user": [0, 1], "game": 0}, + 2, + prob=p, + replace=replace, + fused=fused, ) - assert len(subg.ntypes) == 3 + if not fused: + assert len(subg.ntypes) == 3 + assert len(subg.srctypes) == 3 + assert len(subg.dsttypes) == 3 assert len(subg.etypes) == 4 assert subg["follow"].num_edges() == 4 assert subg["play"].num_edges() == 2 if replace else 1 @@ -672,8 +702,12 @@ def _test3(p, replace): {"user": [0, 1], "game": 0, "coin": 0}, {"follow": 1, "play": 2, "liked-by": 0, "flips": -1}, replace=True, + fused=fused, ) - assert len(subg.ntypes) == 3 + if not fused: + assert len(subg.ntypes) == 3 + assert len(subg.srctypes) == 3 + assert len(subg.dsttypes) == 3 assert len(subg.etypes) == 4 assert subg["follow"].num_edges() == 2 assert subg["play"].num_edges() == 2 @@ -795,15 +829,19 @@ def _test3(p): assert subg["flips"].num_edges() == 4 -def _test_sample_neighbors_outedge(hypersparse): +def _test_sample_neighbors_outedge(hypersparse, fused): g, hg = _gen_neighbor_sampling_test_graph(hypersparse, True) def _test1(p, replace): subg = dgl.sampling.sample_neighbors( - g, [0, 1], -1, prob=p, replace=replace, edge_dir="out" + g, [0, 1], -1, prob=p, replace=replace, edge_dir="out", fused=fused ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() + u, v = subg.edges() + if fused: + u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v] u_ans, v_ans, e_ans = g.out_edges([0, 1], form="all") if p is not None: emask = F.gather_row(g.edata[p], e_ans) @@ -817,11 +855,20 @@ def _test1(p, replace): for i in range(10): subg = dgl.sampling.sample_neighbors( - g, [0, 1], 2, prob=p, replace=replace, edge_dir="out" + g, + [0, 1], + 2, + prob=p, + replace=replace, + edge_dir="out", + fused=fused, ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() assert subg.num_edges() == 4 u, v = subg.edges() + if fused: + u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v] assert set(F.asnumpy(F.unique(u))) == {0, 1} assert F.array_equal( F.astype(g.has_edges_between(u, v), F.int64), @@ -843,10 +890,13 @@ def _test1(p, replace): def _test2(p, replace): # fanout > #neighbors subg = dgl.sampling.sample_neighbors( - g, [0, 2], -1, prob=p, replace=replace, edge_dir="out" + g, [0, 2], -1, prob=p, replace=replace, edge_dir="out", fused=fused ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() u, v = subg.edges() + if fused: + u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v] u_ans, v_ans, e_ans = g.out_edges([0, 2], form="all") if p is not None: emask = F.gather_row(g.edata[p], e_ans) @@ -860,12 +910,22 @@ def _test2(p, replace): # fanout > #neighbors for i in range(10): subg = dgl.sampling.sample_neighbors( - g, [0, 2], 2, prob=p, replace=replace, edge_dir="out" + g, + [0, 2], + 2, + prob=p, + replace=replace, + edge_dir="out", + fused=fused, ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() num_edges = 4 if replace else 3 assert subg.num_edges() == num_edges u, v = subg.edges() + if fused: + u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v] + assert set(F.asnumpy(F.unique(u))) == {0, 2} assert F.array_equal( F.astype(g.has_edges_between(u, v), F.int64), @@ -892,8 +952,13 @@ def _test3(p, replace): prob=p, replace=replace, edge_dir="out", + fused=fused, ) - assert len(subg.ntypes) == 3 + + if not fused: + assert len(subg.ntypes) == 3 + assert len(subg.srctypes) == 3 + assert len(subg.dsttypes) == 3 assert len(subg.etypes) == 4 assert subg["follow"].num_edges() == 6 if p is None else 4 assert subg["play"].num_edges() == 1 @@ -908,8 +973,12 @@ def _test3(p, replace): prob=p, replace=replace, edge_dir="out", + fused=fused, ) - assert len(subg.ntypes) == 3 + if not fused: + assert len(subg.ntypes) == 3 + assert len(subg.srctypes) == 3 + assert len(subg.dsttypes) == 3 assert len(subg.etypes) == 4 assert subg["follow"].num_edges() == 4 assert subg["play"].num_edges() == 2 if replace else 1 @@ -1077,7 +1146,9 @@ def _test3(): def test_sample_neighbors_noprob(): - _test_sample_neighbors(False, None) + _test_sample_neighbors(False, None, False) + if F._default_context_str != "gpu": + _test_sample_neighbors(False, None, True) # _test_sample_neighbors(True) @@ -1086,7 +1157,9 @@ def test_sample_labors_noprob(): def test_sample_neighbors_prob(): - _test_sample_neighbors(False, "prob") + _test_sample_neighbors(False, "prob", False) + if F._default_context_str != "gpu": + _test_sample_neighbors(False, "prob", True) # _test_sample_neighbors(True) @@ -1095,7 +1168,9 @@ def test_sample_labors_prob(): def test_sample_neighbors_outedge(): - _test_sample_neighbors_outedge(False) + _test_sample_neighbors_outedge(False, False) + if F._default_context_str != "gpu": + _test_sample_neighbors_outedge(False, True) # _test_sample_neighbors_outedge(True) @@ -1107,7 +1182,8 @@ def test_sample_neighbors_outedge(): reason="GPU sample neighbors with mask not implemented", ) def test_sample_neighbors_mask(): - _test_sample_neighbors(False, "mask") + _test_sample_neighbors(False, "mask", False) + _test_sample_neighbors(False, "mask", True) @unittest.skipIf( @@ -1128,22 +1204,45 @@ def test_sample_neighbors_topk_outedge(): # _test_sample_neighbors_topk_outedge(True) -def test_sample_neighbors_with_0deg(): +@pytest.mark.parametrize("fused", [False, True]) +def test_sample_neighbors_with_0deg(fused): + if fused and F._default_context_str == "gpu": + pytest.skip("Fused sampling doesn't support GPU.") g = dgl.graph(([], []), num_nodes=5).to(F.ctx()) sg = dgl.sampling.sample_neighbors( - g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="in", replace=False + g, + F.tensor([1, 2], dtype=F.int64), + 2, + edge_dir="in", + replace=False, + fused=fused, ) assert sg.num_edges() == 0 sg = dgl.sampling.sample_neighbors( - g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="in", replace=True + g, + F.tensor([1, 2], dtype=F.int64), + 2, + edge_dir="in", + replace=True, + fused=fused, ) assert sg.num_edges() == 0 sg = dgl.sampling.sample_neighbors( - g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="out", replace=False + g, + F.tensor([1, 2], dtype=F.int64), + 2, + edge_dir="out", + replace=False, + fused=fused, ) assert sg.num_edges() == 0 sg = dgl.sampling.sample_neighbors( - g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="out", replace=True + g, + F.tensor([1, 2], dtype=F.int64), + 2, + edge_dir="out", + replace=True, + fused=fused, ) assert sg.num_edges() == 0 @@ -1274,7 +1373,7 @@ def check_num(nodes, tag): ) def test_sample_neighbors_biased_bipartite(): g = create_test_graph(100, 30, True) - num_dst = g.number_of_dst_nodes() + num_dst = g.num_dst_nodes() bias = F.tensor([0, 0.01, 10, 10], dtype=F.float32) def check_num(nodes, tag): @@ -1492,7 +1591,10 @@ def test_sample_neighbors_etype_sorted_homogeneous(format_, direction): @pytest.mark.parametrize("dtype", ["int32", "int64"]) -def test_sample_neighbors_exclude_edges_heteroG(dtype): +@pytest.mark.parametrize("fused", [False, True]) +def test_sample_neighbors_exclude_edges_heteroG(dtype, fused): + if fused and F._default_context_str == "gpu": + pytest.skip("Fused sampling doesn't support GPU.") d_i_d_u_nodes = F.zerocopy_from_numpy( np.unique(np.random.randint(300, size=100, dtype=dtype)) ) @@ -1574,39 +1676,85 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype): }, sampled_amount, exclude_edges=excluded_edges, + fused=fused, ) - assert not np.any( - F.asnumpy( - sg.has_edges_between( - did_excluded_nodes_U, - did_excluded_nodes_V, - etype=("drug", "interacts", "drug"), + if fused: + + def contain_edge(g, sg, etype, u, v): + # set of subgraph graph edges deduced from original graph + org_edges = set( + map( + tuple, + np.stack( + g.find_edges(sg.edges[etype].data[dgl.EID], etype), + axis=1, + ), + ) ) + # set of excluded edges + excluded_edges = set(map(tuple, np.stack((u, v), axis=1))) + + diff_set = org_edges - excluded_edges + + return len(diff_set) != len(org_edges) + + assert not contain_edge( + g, + sg, + ("drug", "interacts", "drug"), + did_excluded_nodes_U, + did_excluded_nodes_V, ) - ) - assert not np.any( - F.asnumpy( - sg.has_edges_between( - dig_excluded_nodes_U, - dig_excluded_nodes_V, - etype=("drug", "interacts", "gene"), + assert not contain_edge( + g, + sg, + ("drug", "interacts", "gene"), + dig_excluded_nodes_U, + dig_excluded_nodes_V, + ) + assert not contain_edge( + g, + sg, + ("drug", "treats", "disease"), + dtd_excluded_nodes_U, + dtd_excluded_nodes_V, + ) + else: + assert not np.any( + F.asnumpy( + sg.has_edges_between( + did_excluded_nodes_U, + did_excluded_nodes_V, + etype=("drug", "interacts", "drug"), + ) ) ) - ) - assert not np.any( - F.asnumpy( - sg.has_edges_between( - dtd_excluded_nodes_U, - dtd_excluded_nodes_V, - etype=("drug", "treats", "disease"), + assert not np.any( + F.asnumpy( + sg.has_edges_between( + dig_excluded_nodes_U, + dig_excluded_nodes_V, + etype=("drug", "interacts", "gene"), + ) + ) + ) + assert not np.any( + F.asnumpy( + sg.has_edges_between( + dtd_excluded_nodes_U, + dtd_excluded_nodes_V, + etype=("drug", "treats", "disease"), + ) ) ) - ) @pytest.mark.parametrize("dtype", ["int32", "int64"]) -def test_sample_neighbors_exclude_edges_homoG(dtype): +@pytest.mark.parametrize("fused", [False, True]) +def test_sample_neighbors_exclude_edges_homoG(dtype, fused): + if fused and F._default_context_str == "gpu": + pytest.skip("Fused sampling doesn't support GPU.") u_nodes = F.zerocopy_from_numpy( np.unique(np.random.randint(300, size=100, dtype=dtype)) ) @@ -1630,12 +1778,36 @@ def test_sample_neighbors_exclude_edges_homoG(dtype): excluded_nodes_V = g_edges[V][b_idx:e_idx] sg = dgl.sampling.sample_neighbors( - g, sampled_node, sampled_amount, exclude_edges=excluded_edges + g, + sampled_node, + sampled_amount, + exclude_edges=excluded_edges, + fused=fused, ) + if fused: + + def contain_edge(g, sg, u, v): + # set of subgraph graph edges deduced from original graph + org_edges = set( + map( + tuple, + np.stack( + g.find_edges(sg.edges["_E"].data[dgl.EID]), axis=1 + ), + ) + ) + # set of excluded edges + excluded_edges = set(map(tuple, np.stack((u, v), axis=1))) - assert not np.any( - F.asnumpy(sg.has_edges_between(excluded_nodes_U, excluded_nodes_V)) - ) + diff_set = org_edges - excluded_edges + + return len(diff_set) != len(org_edges) + + assert not contain_edge(g, sg, excluded_nodes_U, excluded_nodes_V) + else: + assert not np.any( + F.asnumpy(sg.has_edges_between(excluded_nodes_U, excluded_nodes_V)) + ) @pytest.mark.parametrize("dtype", ["int32", "int64"]) From 646a528f6da3cc3881a17243f29714845dc07721 Mon Sep 17 00:00:00 2001 From: Adam Grabowski Date: Wed, 28 Jun 2023 14:25:39 +0000 Subject: [PATCH 02/10] Replace non-constant references with pointers. --- include/dgl/aten/csr.h | 2 +- include/dgl/sampling/neighbor.h | 2 +- src/array/array.cc | 10 ++--- src/array/array_op.h | 4 +- src/array/cpu/rowwise_pick.h | 6 +-- src/array/cpu/rowwise_sampling.cc | 44 +++++++++---------- src/graph/sampling/neighbor/neighbor.cc | 56 +++++++++++++++---------- 7 files changed, 68 insertions(+), 56 deletions(-) diff --git a/include/dgl/aten/csr.h b/include/dgl/aten/csr.h index 3ef1571359b9..6886aa6f1976 100644 --- a/include/dgl/aten/csr.h +++ b/include/dgl/aten/csr.h @@ -635,7 +635,7 @@ COOMatrix CSRRowWiseSampling( template std::pair CSRRowWiseSamplingFused( CSRMatrix mat, IdArray rows, IdArray seed_mapping, - std::vector& new_seed_nodes, int64_t num_samples, + std::vector* new_seed_nodes, int64_t num_samples, NDArray prob_or_mask = NDArray(), bool replace = true); /** diff --git a/include/dgl/sampling/neighbor.h b/include/dgl/sampling/neighbor.h index 375618eb77d9..1c3a4bb64eb3 100644 --- a/include/dgl/sampling/neighbor.h +++ b/include/dgl/sampling/neighbor.h @@ -93,7 +93,7 @@ template std::tuple, std::vector> SampleNeighborsFused( const HeteroGraphPtr hg, const std::vector& nodes, - std::vector& mapping, const std::vector& fanouts, + const std::vector& mapping, const std::vector& fanouts, EdgeDir dir, const std::vector& prob_or_mask, const std::vector& exclude_edges, bool replace = true); diff --git a/src/array/array.cc b/src/array/array.cc index d4cf3693e704..7472a4c3e09f 100644 --- a/src/array/array.cc +++ b/src/array/array.cc @@ -600,7 +600,7 @@ COOMatrix CSRRowWiseSampling( template std::pair CSRRowWiseSamplingFused( CSRMatrix mat, IdArray rows, IdArray seed_mapping, - std::vector& new_seed_nodes, int64_t num_samples, + std::vector* new_seed_nodes, int64_t num_samples, NDArray prob_or_mask, bool replace) { std::pair ret; if (IsNullArray(prob_or_mask)) { @@ -627,16 +627,16 @@ std::pair CSRRowWiseSamplingFused( } template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); COOMatrix CSRRowWisePerEtypeSampling( CSRMatrix mat, IdArray rows, const std::vector& eid2etype_offset, diff --git a/src/array/array_op.h b/src/array/array_op.h index 91e1d2f1a56a..ceb15492a07c 100644 --- a/src/array/array_op.h +++ b/src/array/array_op.h @@ -183,7 +183,7 @@ template < DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes> std::pair CSRRowWiseSamplingFused( CSRMatrix mat, IdArray rows, IdArray seed_mapping, - std::vector& new_seed_nodes, int64_t num_samples, + std::vector* new_seed_nodes, int64_t num_samples, NDArray prob_or_mask, bool replace); // FloatType is the type of probability data. @@ -201,7 +201,7 @@ COOMatrix CSRRowWiseSamplingUniform( template std::pair CSRRowWiseSamplingUniformFused( CSRMatrix mat, IdArray rows, IdArray seed_mapping, - std::vector& new_seed_nodes, int64_t num_samples, bool replace); + std::vector* new_seed_nodes, int64_t num_samples, bool replace); template COOMatrix CSRRowWisePerEtypeSamplingUniform( diff --git a/src/array/cpu/rowwise_pick.h b/src/array/cpu/rowwise_pick.h index 2edcacba619e..7e34976e46f0 100644 --- a/src/array/cpu/rowwise_pick.h +++ b/src/array/cpu/rowwise_pick.h @@ -98,7 +98,7 @@ using EtypeRangePickFn = std::function std::pair CSRRowWisePickFused( CSRMatrix mat, IdArray rows, IdArray seed_mapping, - std::vector& new_seed_nodes, int64_t num_picks, bool replace, + std::vector* new_seed_nodes, int64_t num_picks, bool replace, PickFn pick_fn, NumPicksFn num_picks_fn) { using namespace aten; @@ -195,8 +195,8 @@ std::pair CSRRowWisePickFused( const IdxType num_cols = picked_col->shape[0]; if (map_seed_nodes) { - new_seed_nodes.resize(num_rows); - memcpy(new_seed_nodes.data(), rows_data, sizeof(IdxType) * num_rows); + (*new_seed_nodes).resize(num_rows); + memcpy((*new_seed_nodes).data(), rows_data, sizeof(IdxType) * num_rows); } return std::make_pair( diff --git a/src/array/cpu/rowwise_sampling.cc b/src/array/cpu/rowwise_sampling.cc index 911f69b42371..222a009341f9 100644 --- a/src/array/cpu/rowwise_sampling.cc +++ b/src/array/cpu/rowwise_sampling.cc @@ -229,7 +229,7 @@ template < DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes> std::pair CSRRowWiseSamplingFused( CSRMatrix mat, IdArray rows, IdArray seed_mapping, - std::vector& new_seed_nodes, int64_t num_samples, + std::vector* new_seed_nodes, int64_t num_samples, NDArray prob_or_mask, bool replace) { // If num_samples is -1, select all neighbors without replacement. replace = (replace && num_samples != -1); @@ -245,53 +245,53 @@ std::pair CSRRowWiseSamplingFused( template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template std::pair CSRRowWiseSamplingFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); template COOMatrix CSRRowWisePerEtypeSampling( @@ -354,7 +354,7 @@ template COOMatrix CSRRowWiseSamplingUniform( template std::pair CSRRowWiseSamplingUniformFused( CSRMatrix mat, IdArray rows, IdArray seed_mapping, - std::vector& new_seed_nodes, int64_t num_samples, bool replace) { + std::vector* new_seed_nodes, int64_t num_samples, bool replace) { // If num_samples is -1, select all neighbors without replacement. replace = (replace && num_samples != -1); auto num_picks_fn = @@ -367,16 +367,16 @@ std::pair CSRRowWiseSamplingUniformFused( template std::pair CSRRowWiseSamplingUniformFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, bool); template std::pair CSRRowWiseSamplingUniformFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, bool); template std::pair CSRRowWiseSamplingUniformFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, bool); template std::pair CSRRowWiseSamplingUniformFused( - CSRMatrix, IdArray, IdArray, std::vector&, int64_t, bool); + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, bool); template COOMatrix CSRRowWisePerEtypeSamplingUniform( diff --git a/src/graph/sampling/neighbor/neighbor.cc b/src/graph/sampling/neighbor/neighbor.cc index ae9da862eb98..10bb56f742fe 100644 --- a/src/graph/sampling/neighbor/neighbor.cc +++ b/src/graph/sampling/neighbor/neighbor.cc @@ -26,11 +26,11 @@ namespace sampling { template void ExcludeCertainEdgesFused( - std::vector& sampled_graphs, std::vector& induced_edges, - std::vector& sampled_coo_rows, + std::vector* sampled_graphs, std::vector* induced_edges, + std::vector* sampled_coo_rows, const std::vector& exclude_edges, std::vector* weights = nullptr) { - int etypes = sampled_graphs.size(); + int etypes = (*sampled_graphs).size(); std::vector remain_induced_edges(etypes); std::vector remain_indptrs(etypes); std::vector remain_indices(etypes); @@ -38,8 +38,8 @@ void ExcludeCertainEdgesFused( std::vector remain_weights(etypes); for (int etype = 0; etype < etypes; ++etype) { if (exclude_edges[etype].GetSize() == 0 || - sampled_graphs[etype].num_rows == 0) { - remain_induced_edges[etype] = induced_edges[etype]; + (*sampled_graphs)[etype].num_rows == 0) { + remain_induced_edges[etype] = (*induced_edges)[etype]; if (weights) remain_weights[etype] = (*weights)[etype]; continue; } @@ -47,10 +47,10 @@ void ExcludeCertainEdgesFused( ? (*weights)[etype]->dtype : DGLDataType{kDGLFloat, 8 * sizeof(float), 1}; ATEN_FLOAT_TYPE_SWITCH(dtype, FloatType, "weights", { - IdType* indptr = sampled_graphs[etype].indptr.Ptr(); - IdType* indices = sampled_graphs[etype].indices.Ptr(); - IdType* coo_rows = sampled_coo_rows[etype].Ptr(); - IdType* induced_edges_data = induced_edges[etype].Ptr(); + IdType* indptr = (*sampled_graphs)[etype].indptr.Ptr(); + IdType* indices = (*sampled_graphs)[etype].indices.Ptr(); + IdType* coo_rows = (*sampled_coo_rows)[etype].Ptr(); + IdType* induced_edges_data = (*induced_edges)[etype].Ptr(); FloatType* weights_data = weights && (*weights)[etype]->shape[0] ? (*weights)[etype].Ptr() : nullptr; @@ -60,7 +60,7 @@ void ExcludeCertainEdgesFused( exclude_edges[etype].Ptr() + exclude_edges_len); const IdType* exclude_edges_data = exclude_edges[etype].Ptr(); IdType outIndices = 0; - for (IdType row = 0; row < sampled_graphs[etype].indptr->shape[0] - 1; + for (IdType row = 0; row < (*sampled_graphs)[etype].indptr->shape[0] - 1; ++row) { auto tmp_row = indptr[row]; if (outIndices != indptr[row]) indptr[row] = outIndices; @@ -76,19 +76,19 @@ void ExcludeCertainEdgesFused( } } } - indptr[sampled_graphs[etype].indptr->shape[0] - 1] = outIndices; + indptr[(*sampled_graphs)[etype].indptr->shape[0] - 1] = outIndices; remain_induced_edges[etype] = - aten::IndexSelect(induced_edges[etype], 0, outIndices); + aten::IndexSelect((*induced_edges)[etype], 0, outIndices); remain_weights[etype] = weights_data ? aten::IndexSelect((*weights)[etype], 0, outIndices) : NullArray(); remain_indices[etype] = - aten::IndexSelect(sampled_graphs[etype].indices, 0, outIndices); - sampled_coo_rows[etype] = - aten::IndexSelect(sampled_coo_rows[etype], 0, outIndices); - sampled_graphs[etype] = CSRMatrix( - sampled_graphs[etype].num_rows, outIndices, - sampled_graphs[etype].indptr, remain_indices[etype], + aten::IndexSelect((*sampled_graphs)[etype].indices, 0, outIndices); + (*sampled_coo_rows)[etype] = + aten::IndexSelect((*sampled_coo_rows)[etype], 0, outIndices); + (*sampled_graphs)[etype] = CSRMatrix( + (*sampled_graphs)[etype].num_rows, outIndices, + (*sampled_graphs)[etype].indptr, remain_indices[etype], remain_induced_edges[etype]); }); } @@ -342,7 +342,7 @@ template std::tuple, std::vector> SampleNeighborsFused( const HeteroGraphPtr hg, const std::vector& nodes, - std::vector& mapping, const std::vector& fanouts, + const std::vector& mapping, const std::vector& fanouts, EdgeDir dir, const std::vector& prob_or_mask, const std::vector& exclude_edges, bool replace) { CHECK_EQ(nodes.size(), hg->NumVertexTypes()) @@ -393,14 +393,14 @@ SampleNeighborsFused( // therefore two diffrent mappings and node vectors are needed sampled_graph = sampling_fn( hg->GetCSRMatrix(etype), nodes_ntype, mapping[src_vtype], - new_nodes_vec[src_vtype], fanouts[etype], prob_or_mask[etype], + &new_nodes_vec[src_vtype], fanouts[etype], prob_or_mask[etype], replace); break; case SparseFormat::kCSC: CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix."; sampled_graph = sampling_fn( hg->GetCSCMatrix(etype), nodes_ntype, mapping[dst_vtype], - new_nodes_vec[dst_vtype], fanouts[etype], prob_or_mask[etype], + &new_nodes_vec[dst_vtype], fanouts[etype], prob_or_mask[etype], replace); break; default: @@ -419,7 +419,7 @@ SampleNeighborsFused( if (!exclude_edges.empty()) { ExcludeCertainEdgesFused( - sampled_graphs, induced_edges, sampled_coo_rows, exclude_edges); + &sampled_graphs, &induced_edges, &sampled_coo_rows, exclude_edges); for (size_t i = 0; i < hg->NumEdgeTypes(); i++) { if (sampled_graphs[i].data.defined()) induced_edges[i] = std::move(sampled_graphs[i].data); @@ -561,6 +561,18 @@ SampleNeighborsFused( return std::make_tuple(new_graph, induced_edges, induced_vertices); } +template std::tuple, std::vector> +SampleNeighborsFused( + const HeteroGraphPtr, const std::vector&, + const std::vector&, const std::vector&, EdgeDir, + const std::vector&, const std::vector&, bool); + +template std::tuple, std::vector> +SampleNeighborsFused( + const HeteroGraphPtr, const std::vector&, + const std::vector&, const std::vector&, EdgeDir, + const std::vector&, const std::vector&, bool); + HeteroSubgraph SampleNeighborsEType( const HeteroGraphPtr hg, const IdArray nodes, const std::vector& eid2etype_offset, From 2afa96dec4b1250139d15d454c7010fa0aea742b Mon Sep 17 00:00:00 2001 From: Adam Grabowski Date: Wed, 28 Jun 2023 14:39:39 +0000 Subject: [PATCH 03/10] Fix neighbor.py lint issues --- python/dgl/sampling/neighbor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/dgl/sampling/neighbor.py b/python/dgl/sampling/neighbor.py index d96372ca39f6..a33a4305a734 100644 --- a/python/dgl/sampling/neighbor.py +++ b/python/dgl/sampling/neighbor.py @@ -6,7 +6,7 @@ from .. import backend as F, ndarray as nd, utils from .._ffi.function import _init_api -from ..base import DGLError, EID, NID +from ..base import DGLError, EID from ..heterograph import DGLBlock, DGLGraph from .utils import EidExcluder @@ -478,7 +478,7 @@ def _sample_neighbors( if F.device_type(g.device) != "cpu": raise DGLError("Only cpu is supported in fused sampling") - if mapping == None: + if mapping is None: mapping = {} mapping_name = "__mapping" + str(os.getpid()) if mapping_name not in mapping.keys(): From 7a345c5b4e8636c0faf6410ff5235472074eb53e Mon Sep 17 00:00:00 2001 From: Adam Grabowski Date: Thu, 29 Jun 2023 13:12:17 +0000 Subject: [PATCH 04/10] Create Windows path with atomic operation --- src/array/cpu/concurrent_id_hash_map.cc | 22 ++++++++++++++++++++++ src/array/cpu/concurrent_id_hash_map.h | 3 +++ src/graph/sampling/neighbor/neighbor.cc | 8 +++++--- 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/array/cpu/concurrent_id_hash_map.cc b/src/array/cpu/concurrent_id_hash_map.cc index 46e444eb9ebf..2f9bde7c5dce 100644 --- a/src/array/cpu/concurrent_id_hash_map.cc +++ b/src/array/cpu/concurrent_id_hash_map.cc @@ -223,5 +223,27 @@ ConcurrentIdHashMap::AttemptInsertAt(int64_t pos, IdType key) { template class ConcurrentIdHashMap; template class ConcurrentIdHashMap; +template +bool BoolCompareAndSwap(IdType* ptr) { +#ifdef _MSC_VER + if (sizeof(IdType) == 4) { + return _InterlockedCompareExchange(reinterpret_cast(ptr), 0, -1) == + -1; + } else if (sizeof(IdType) == 8) { + return _InterlockedCompareExchange64( + reinterpret_cast(ptr), 0, -1) == -1; + } else { + LOG(FATAL) << "ID can only be int32 or int64"; + } +#elif __GNUC__ // _MSC_VER + return __sync_bool_compare_and_swap(ptr, -1, 0); +#else // _MSC_VER +#error "CompareAndSwap is not supported on this platform." +#endif // _MSC_VER +} + +template bool BoolCompareAndSwap(int32_t*); +template bool BoolCompareAndSwap(int64_t*); + } // namespace aten } // namespace dgl diff --git a/src/array/cpu/concurrent_id_hash_map.h b/src/array/cpu/concurrent_id_hash_map.h index db09ee902130..aa3a32a2fdcb 100644 --- a/src/array/cpu/concurrent_id_hash_map.h +++ b/src/array/cpu/concurrent_id_hash_map.h @@ -195,6 +195,9 @@ class ConcurrentIdHashMap { IdType mask_; }; +template +bool BoolCompareAndSwap(IdType* ptr); + } // namespace aten } // namespace dgl diff --git a/src/graph/sampling/neighbor/neighbor.cc b/src/graph/sampling/neighbor/neighbor.cc index 10bb56f742fe..5393a200123d 100644 --- a/src/graph/sampling/neighbor/neighbor.cc +++ b/src/graph/sampling/neighbor/neighbor.cc @@ -15,6 +15,7 @@ #include #include +#include "../../../array/cpu/concurrent_id_hash_map.h" #include "../../../c_api_common.h" #include "../../unit_graph.h" @@ -438,7 +439,7 @@ SampleNeighborsFused( (dir == EdgeDir::kIn) ? src_vtype : dst_vtype; if (sampled_graphs[etype].num_cols != 0) { auto num_cols = sampled_graphs[etype].num_cols; - const int num_threads_col = runtime::compute_num_threads(0, num_cols, 1); + int num_threads_col = runtime::compute_num_threads(0, num_cols, 1); std::vector global_prefix_col(num_threads_col + 1, 0); std::vector> src_nodes_local(num_threads_col); IdType* mapping_data_dst = mapping[lhs_node_type].Ptr(); @@ -446,6 +447,7 @@ SampleNeighborsFused( #pragma omp parallel num_threads(num_threads_col) { const int thread_id = omp_get_thread_num(); + num_threads_col = omp_get_num_threads(); const int64_t start_i = thread_id * (num_cols / num_threads_col) + @@ -458,8 +460,8 @@ SampleNeighborsFused( assert(thread_id + 1 < num_threads_col || end_i == num_cols); for (int64_t i = start_i; i < end_i; ++i) { int64_t picked_idx = cdata[i]; - bool spot_claimed = __sync_bool_compare_and_swap( - &mapping_data_dst[picked_idx], -1, 0); + bool spot_claimed = + BoolCompareAndSwap(&mapping_data_dst[picked_idx]); if (spot_claimed) src_nodes_local[thread_id].push_back(picked_idx); } global_prefix_col[thread_id + 1] = src_nodes_local[thread_id].size(); From e1494d705df9cb14523359f9c92c91e05ba0e0da Mon Sep 17 00:00:00 2001 From: Adam Grabowski Date: Fri, 30 Jun 2023 11:21:43 +0000 Subject: [PATCH 05/10] Add backend checks for PyTorch only --- python/dgl/dataloading/neighbor_sampler.py | 6 ++++- python/dgl/sampling/neighbor.py | 8 ++++--- tests/python/common/sampling/test_sampling.py | 24 ++++++++++++------- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/python/dgl/dataloading/neighbor_sampler.py b/python/dgl/dataloading/neighbor_sampler.py index 3243ee1d4dd6..137d01f4d2a7 100644 --- a/python/dgl/dataloading/neighbor_sampler.py +++ b/python/dgl/dataloading/neighbor_sampler.py @@ -149,7 +149,11 @@ def __init__( def sample_blocks(self, g, seed_nodes, exclude_eids=None): output_nodes = seed_nodes blocks = [] - if F.device_type(g.device) == "cpu" and self.fused: + if ( + F.device_type(g.device) == "cpu" + and F.backend_name == "pytorch" + and self.fused + ): if self.g != g: self.mapping = {} self.g = g diff --git a/python/dgl/sampling/neighbor.py b/python/dgl/sampling/neighbor.py index a33a4305a734..4d721205c34c 100644 --- a/python/dgl/sampling/neighbor.py +++ b/python/dgl/sampling/neighbor.py @@ -475,8 +475,10 @@ def _sample_neighbors( raise DGLError( "distributed training not supported in fused sampling" ) - if F.device_type(g.device) != "cpu": - raise DGLError("Only cpu is supported in fused sampling") + if F.device_type(g.device) != "cpu" or F.backend_name != "pytorch": + raise DGLError( + "Only PyTorch backend and cpu is supported in fused sampling" + ) if mapping is None: mapping = {} @@ -500,7 +502,7 @@ def _sample_neighbors( for mapping_vector, src_nodes in zip( mapping[mapping_name], induced_nodes ): - mapping_vector[F.from_dgl_nd(src_nodes).type(torch.int64)] = -1 + mapping_vector[F.from_dgl_nd(src_nodes).type(F.int64)] = -1 new_ntypes = (g.ntypes, g.ntypes) ret = DGLBlock(subgidx, new_ntypes, g.etypes) diff --git a/tests/python/common/sampling/test_sampling.py b/tests/python/common/sampling/test_sampling.py index 650e502f57e3..6413c40000a9 100644 --- a/tests/python/common/sampling/test_sampling.py +++ b/tests/python/common/sampling/test_sampling.py @@ -1147,7 +1147,7 @@ def _test3(): def test_sample_neighbors_noprob(): _test_sample_neighbors(False, None, False) - if F._default_context_str != "gpu": + if F._default_context_str != "gpu" and F.backend_name == "pytorch": _test_sample_neighbors(False, None, True) # _test_sample_neighbors(True) @@ -1158,7 +1158,7 @@ def test_sample_labors_noprob(): def test_sample_neighbors_prob(): _test_sample_neighbors(False, "prob", False) - if F._default_context_str != "gpu": + if F._default_context_str != "gpu" and F.backend_name == "pytorch": _test_sample_neighbors(False, "prob", True) # _test_sample_neighbors(True) @@ -1169,7 +1169,7 @@ def test_sample_labors_prob(): def test_sample_neighbors_outedge(): _test_sample_neighbors_outedge(False, False) - if F._default_context_str != "gpu": + if F._default_context_str != "gpu" and F.backend_name == "pytorch": _test_sample_neighbors_outedge(False, True) # _test_sample_neighbors_outedge(True) @@ -1206,8 +1206,10 @@ def test_sample_neighbors_topk_outedge(): @pytest.mark.parametrize("fused", [False, True]) def test_sample_neighbors_with_0deg(fused): - if fused and F._default_context_str == "gpu": - pytest.skip("Fused sampling doesn't support GPU.") + if fused and ( + F._default_context_str == "gpu" or F.backend_name != "pytorch" + ): + pytest.skip("Fused sampling support CPU with backend PyTorch.") g = dgl.graph(([], []), num_nodes=5).to(F.ctx()) sg = dgl.sampling.sample_neighbors( g, @@ -1593,8 +1595,10 @@ def test_sample_neighbors_etype_sorted_homogeneous(format_, direction): @pytest.mark.parametrize("dtype", ["int32", "int64"]) @pytest.mark.parametrize("fused", [False, True]) def test_sample_neighbors_exclude_edges_heteroG(dtype, fused): - if fused and F._default_context_str == "gpu": - pytest.skip("Fused sampling doesn't support GPU.") + if fused and ( + F._default_context_str == "gpu" or F.backend_name != "pytorch" + ): + pytest.skip("Fused sampling support CPU with backend PyTorch.") d_i_d_u_nodes = F.zerocopy_from_numpy( np.unique(np.random.randint(300, size=100, dtype=dtype)) ) @@ -1753,8 +1757,10 @@ def contain_edge(g, sg, etype, u, v): @pytest.mark.parametrize("dtype", ["int32", "int64"]) @pytest.mark.parametrize("fused", [False, True]) def test_sample_neighbors_exclude_edges_homoG(dtype, fused): - if fused and F._default_context_str == "gpu": - pytest.skip("Fused sampling doesn't support GPU.") + if fused and ( + F._default_context_str == "gpu" or F.backend_name != "pytorch" + ): + pytest.skip("Fused sampling support CPU with backend PyTorch.") u_nodes = F.zerocopy_from_numpy( np.unique(np.random.randint(300, size=100, dtype=dtype)) ) From 805f7863a8857d0d47ce40d241924f938331642c Mon Sep 17 00:00:00 2001 From: Adam Grabowski Date: Fri, 30 Jun 2023 12:25:29 +0000 Subject: [PATCH 06/10] Fix test_sample_neighbors_mask --- tests/python/common/sampling/test_sampling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/common/sampling/test_sampling.py b/tests/python/common/sampling/test_sampling.py index 6413c40000a9..812b8b1335a3 100644 --- a/tests/python/common/sampling/test_sampling.py +++ b/tests/python/common/sampling/test_sampling.py @@ -1183,7 +1183,8 @@ def test_sample_neighbors_outedge(): ) def test_sample_neighbors_mask(): _test_sample_neighbors(False, "mask", False) - _test_sample_neighbors(False, "mask", True) + if F._default_context_str != "gpu" and F.backend_name == "pytorch": + _test_sample_neighbors(False, "mask", True) @unittest.skipIf( From a4349ebfd9a84e2312c063a8c0bab884748dd2dd Mon Sep 17 00:00:00 2001 From: Adam Grabowski Date: Mon, 3 Jul 2023 13:35:58 +0000 Subject: [PATCH 07/10] Update fused neighbor sampling dispatch --- python/dgl/dataloading/neighbor_sampler.py | 86 ++++++++++++---------- python/dgl/sampling/neighbor.py | 10 ++- 2 files changed, 57 insertions(+), 39 deletions(-) diff --git a/python/dgl/dataloading/neighbor_sampler.py b/python/dgl/dataloading/neighbor_sampler.py index 137d01f4d2a7..9388fe7b4c72 100644 --- a/python/dgl/dataloading/neighbor_sampler.py +++ b/python/dgl/dataloading/neighbor_sampler.py @@ -1,6 +1,7 @@ """Data loading components for neighbor sampling""" from .. import backend as F from ..base import EID, NID +from ..heterograph import DGLGraph from ..transforms import to_block from .base import BlockSampler @@ -149,44 +150,53 @@ def __init__( def sample_blocks(self, g, seed_nodes, exclude_eids=None): output_nodes = seed_nodes blocks = [] - if ( - F.device_type(g.device) == "cpu" - and F.backend_name == "pytorch" - and self.fused - ): - if self.g != g: - self.mapping = {} - self.g = g - for fanout in reversed(self.fanouts): - block = g.sample_neighbors( - seed_nodes, - fanout, - edge_dir=self.edge_dir, - prob=self.prob, - replace=self.replace, - output_device=self.output_device, - fused=True, - exclude_edges=exclude_eids, - mapping=self.mapping, - ) - seed_nodes = block.srcdata[NID] - blocks.insert(0, block) - else: - for fanout in reversed(self.fanouts): - frontier = g.sample_neighbors( - seed_nodes, - fanout, - edge_dir=self.edge_dir, - prob=self.prob, - replace=self.replace, - output_device=self.output_device, - exclude_edges=exclude_eids, - ) - eid = frontier.edata[EID] - block = to_block(frontier, seed_nodes) - block.edata[EID] = eid - seed_nodes = block.srcdata[NID] - blocks.insert(0, block) + + if self.fused: + cpu = F.device_type(g.device) == "cpu" + if type(seed_nodes) is dict: + for ntype in list(seed_nodes.keys()): + if not cpu: + break + cpu = ( + cpu and F.device_type(seed_nodes[ntype].device) == "cpu" + ) + else: + cpu = cpu and F.device_type(seed_nodes.device) == "cpu" + if cpu and type(g) == DGLGraph and F.backend_name == "pytorch": + if self.g != g: + self.mapping = {} + self.g = g + for fanout in reversed(self.fanouts): + block = g.sample_neighbors( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + fused=True, + exclude_edges=exclude_eids, + mapping=self.mapping, + ) + seed_nodes = block.srcdata[NID] + blocks.insert(0, block) + return seed_nodes, output_nodes, blocks + + for fanout in reversed(self.fanouts): + frontier = g.sample_neighbors( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=exclude_eids, + ) + eid = frontier.edata[EID] + block = to_block(frontier, seed_nodes) + block.edata[EID] = eid + seed_nodes = block.srcdata[NID] + blocks.insert(0, block) return seed_nodes, output_nodes, blocks diff --git a/python/dgl/sampling/neighbor.py b/python/dgl/sampling/neighbor.py index 4d721205c34c..0fa9a9a3a6c9 100644 --- a/python/dgl/sampling/neighbor.py +++ b/python/dgl/sampling/neighbor.py @@ -475,7 +475,15 @@ def _sample_neighbors( raise DGLError( "distributed training not supported in fused sampling" ) - if F.device_type(g.device) != "cpu" or F.backend_name != "pytorch": + cpu = F.device_type(g.device) == "cpu" + if type(nodes) is dict: + for ntype in list(nodes.keys()): + if not cpu: + break + cpu = cpu and F.device_type(nodes[ntype].device) == "cpu" + else: + cpu = cpu and F.device_type(nodes.device) == "cpu" + if not cpu or F.backend_name != "pytorch": raise DGLError( "Only PyTorch backend and cpu is supported in fused sampling" ) From 326ba0e49fe0f0769e817edc4992edad2120be4c Mon Sep 17 00:00:00 2001 From: Adam Grabowski Date: Mon, 3 Jul 2023 13:51:02 +0000 Subject: [PATCH 08/10] Replace type() with isinstance() --- python/dgl/dataloading/neighbor_sampler.py | 4 ++-- python/dgl/sampling/neighbor.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/dgl/dataloading/neighbor_sampler.py b/python/dgl/dataloading/neighbor_sampler.py index 9388fe7b4c72..6cb30f98fc69 100644 --- a/python/dgl/dataloading/neighbor_sampler.py +++ b/python/dgl/dataloading/neighbor_sampler.py @@ -153,7 +153,7 @@ def sample_blocks(self, g, seed_nodes, exclude_eids=None): if self.fused: cpu = F.device_type(g.device) == "cpu" - if type(seed_nodes) is dict: + if isinstance(seed_nodes, dict): for ntype in list(seed_nodes.keys()): if not cpu: break @@ -162,7 +162,7 @@ def sample_blocks(self, g, seed_nodes, exclude_eids=None): ) else: cpu = cpu and F.device_type(seed_nodes.device) == "cpu" - if cpu and type(g) == DGLGraph and F.backend_name == "pytorch": + if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch": if self.g != g: self.mapping = {} self.g = g diff --git a/python/dgl/sampling/neighbor.py b/python/dgl/sampling/neighbor.py index 0fa9a9a3a6c9..b409b0e3a906 100644 --- a/python/dgl/sampling/neighbor.py +++ b/python/dgl/sampling/neighbor.py @@ -476,7 +476,7 @@ def _sample_neighbors( "distributed training not supported in fused sampling" ) cpu = F.device_type(g.device) == "cpu" - if type(nodes) is dict: + if isinstance(nodes, dict): for ntype in list(nodes.keys()): if not cpu: break From ce0d545b32f66df8dada76a37e151f5a898ad1c2 Mon Sep 17 00:00:00 2001 From: Adam Grabowski Date: Mon, 17 Jul 2023 13:30:50 +0000 Subject: [PATCH 09/10] Create sample_neighbors_fused() python function --- python/dgl/dataloading/neighbor_sampler.py | 4 +- python/dgl/sampling/neighbor.py | 138 +++++++++++++++--- tests/python/common/sampling/test_sampling.py | 123 +++++----------- 3 files changed, 161 insertions(+), 104 deletions(-) diff --git a/python/dgl/dataloading/neighbor_sampler.py b/python/dgl/dataloading/neighbor_sampler.py index 6cb30f98fc69..2a0c37f417b0 100644 --- a/python/dgl/dataloading/neighbor_sampler.py +++ b/python/dgl/dataloading/neighbor_sampler.py @@ -167,14 +167,12 @@ def sample_blocks(self, g, seed_nodes, exclude_eids=None): self.mapping = {} self.g = g for fanout in reversed(self.fanouts): - block = g.sample_neighbors( + block = g.sample_neighbors_fused( seed_nodes, fanout, edge_dir=self.edge_dir, prob=self.prob, replace=self.replace, - output_device=self.output_device, - fused=True, exclude_edges=exclude_eids, mapping=self.mapping, ) diff --git a/python/dgl/sampling/neighbor.py b/python/dgl/sampling/neighbor.py index b409b0e3a906..66bc692d3d9e 100644 --- a/python/dgl/sampling/neighbor.py +++ b/python/dgl/sampling/neighbor.py @@ -13,6 +13,7 @@ __all__ = [ "sample_etype_neighbors", "sample_neighbors", + "sample_neighbors_fused", "sample_neighbors_biased", "select_topk", ] @@ -218,8 +219,6 @@ def sample_neighbors( _dist_training=False, exclude_edges=None, output_device=None, - fused=False, - mapping=None, ): """Sample neighboring edges of the given nodes and return the induced subgraph. @@ -288,18 +287,6 @@ def sample_neighbors( output_device : Framework-specific device context object, optional The output device. Default is the same as the input graph. - fused : bool, optional - Enables faster version of NeighborSampler that is also compacting output graph, - returning a computational block. Requires nodes to be unique - - (Default: False) - - mapping : dictionary, optional - Used by fused version of sample_neighbors. To avoid constant data allocation - provide empty dictionary ({}) that will be allocated once with proper data and reused - by each function call - - (Default: None) Returns ------- DGLGraph @@ -379,7 +366,123 @@ def sample_neighbors( copy_ndata=copy_ndata, copy_edata=copy_edata, exclude_edges=exclude_edges, - fused=fused, + ) + else: + frontier = _sample_neighbors( + g, + nodes, + fanout, + edge_dir=edge_dir, + prob=prob, + replace=replace, + copy_ndata=copy_ndata, + copy_edata=copy_edata, + ) + if exclude_edges is not None: + eid_excluder = EidExcluder(exclude_edges) + frontier = eid_excluder(frontier) + return frontier if output_device is None else frontier.to(output_device) + + +def sample_neighbors_fused( + g, + nodes, + fanout, + edge_dir="in", + prob=None, + replace=False, + copy_ndata=True, + copy_edata=True, + exclude_edges=None, + mapping=None, +): + """Sample neighboring edges of the given nodes and return the induced subgraph. + + For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges + will be randomly chosen. The graph returned will then contain all the nodes in the + original graph, but only the sampled edges. Nodes will be renumbered starting from id 0, + which would be new node id of first seed node. + + Parameters + ---------- + g : DGLGraph + The graph. Can be either on CPU or GPU. + nodes : tensor or dict + Node IDs to sample neighbors from. + + This argument can take a single ID tensor or a dictionary of node types and ID tensors. + If a single tensor is given, the graph must only have one type of nodes. + fanout : int or dict[etype, int] + The number of edges to be sampled for each node on each edge type. + + This argument can take a single int or a dictionary of edge types and ints. + If a single int is given, DGL will sample this number of edges for each node for + every edge type. + + If -1 is given for a single edge type, all the neighboring edges with that edge + type and non-zero probability will be selected. + edge_dir : str, optional + Determines whether to sample inbound or outbound edges. + + Can take either ``in`` for inbound edges or ``out`` for outbound edges. + prob : str, optional + Feature name used as the (unnormalized) probabilities associated with each + neighboring edge of a node. The feature must have only one element for each + edge. + + The features must be non-negative floats or boolean. Otherwise, the result + will be undefined. + exclude_edges: tensor or dict + Edge IDs to exclude during sampling neighbors for the seed nodes. + + This argument can take a single ID tensor or a dictionary of edge types and ID tensors. + If a single tensor is given, the graph must only have one type of nodes. + replace : bool, optional + If True, sample with replacement. + copy_ndata: bool, optional + If True, the node features of the new graph are copied from + the original graph. If False, the new graph will not have any + node features. + + (Default: True) + copy_edata: bool, optional + If True, the edge features of the new graph are copied from + the original graph. If False, the new graph will not have any + edge features. + + (Default: False) + + mapping : dictionary, optional + Used by fused version of NeighborSampler. To avoid constant data allocation + provide empty dictionary ({}) that will be allocated once with proper data and reused + by each function call + + (Default: None) + Returns + ------- + DGLGraph + A sampled subgraph containing only the sampled neighboring edges. + + Notes + ----- + If :attr:`copy_ndata` or :attr:`copy_edata` is True, same tensors are used as + the node or edge features of the original graph and the new graph. + As a result, users should avoid performing in-place operations + on the node features of the new graph to avoid feature corruption. + + """ + if not g.is_pinned(): + frontier = _sample_neighbors( + g, + nodes, + fanout, + edge_dir=edge_dir, + prob=prob, + replace=replace, + copy_ndata=copy_ndata, + copy_edata=copy_edata, + exclude_edges=exclude_edges, + fused=True, mapping=mapping, ) else: @@ -392,13 +495,13 @@ def sample_neighbors( replace=replace, copy_ndata=copy_ndata, copy_edata=copy_edata, - fused=fused, + fused=True, mapping=mapping, ) if exclude_edges is not None: eid_excluder = EidExcluder(exclude_edges) frontier = eid_excluder(frontier) - return frontier if output_device is None else frontier.to(output_device) + return frontier def _sample_neighbors( @@ -569,6 +672,7 @@ def _sample_neighbors( DGLGraph.sample_neighbors = utils.alias_func(sample_neighbors) +DGLGraph.sample_neighbors_fused = utils.alias_func(sample_neighbors_fused) def sample_neighbors_biased( diff --git a/tests/python/common/sampling/test_sampling.py b/tests/python/common/sampling/test_sampling.py index 812b8b1335a3..1964eb80c6d0 100644 --- a/tests/python/common/sampling/test_sampling.py +++ b/tests/python/common/sampling/test_sampling.py @@ -7,6 +7,11 @@ import numpy as np import pytest +sample_neighbors_fusing_mode = { + True: dgl.sampling.sample_neighbors_fused, + False: dgl.sampling.sample_neighbors, +} + def check_random_walk(g, metapath, traces, ntypes, prob=None, trace_eids=None): traces = F.asnumpy(traces) @@ -559,8 +564,8 @@ def _test_sample_neighbors(hypersparse, prob, fused): g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False) def _test1(p, replace): - subg = dgl.sampling.sample_neighbors( - g, [0, 1], -1, prob=p, replace=replace, fused=fused + subg = sample_neighbors_fusing_mode[fused]( + g, [0, 1], -1, prob=p, replace=replace ) if not fused: assert subg.num_nodes() == g.num_nodes() @@ -579,8 +584,8 @@ def _test1(p, replace): assert uv == uv_ans for i in range(10): - subg = dgl.sampling.sample_neighbors( - g, [0, 1], 2, prob=p, replace=replace, fused=fused + subg = sample_neighbors_fusing_mode[fused]( + g, [0, 1], 2, prob=p, replace=replace ) if not fused: assert subg.num_nodes() == g.num_nodes() @@ -608,8 +613,8 @@ def _test1(p, replace): _test1(prob, False) # w/o replacement, uniform def _test2(p, replace): # fanout > #neighbors - subg = dgl.sampling.sample_neighbors( - g, [0, 2], -1, prob=p, replace=replace, fused=fused + subg = sample_neighbors_fusing_mode[fused]( + g, [0, 2], -1, prob=p, replace=replace ) if not fused: assert subg.num_nodes() == g.num_nodes() @@ -628,8 +633,8 @@ def _test2(p, replace): # fanout > #neighbors assert uv == uv_ans for i in range(10): - subg = dgl.sampling.sample_neighbors( - g, [0, 2], 2, prob=p, replace=replace, fused=fused + subg = sample_neighbors_fusing_mode[fused]( + g, [0, 2], 2, prob=p, replace=replace ) if not fused: assert subg.num_nodes() == g.num_nodes() @@ -655,13 +660,8 @@ def _test2(p, replace): # fanout > #neighbors _test2(prob, False) # w/o replacement, uniform def _test3(p, replace): - subg = dgl.sampling.sample_neighbors( - hg, - {"user": [0, 1], "game": 0}, - -1, - prob=p, - replace=replace, - fused=fused, + subg = sample_neighbors_fusing_mode[fused]( + hg, {"user": [0, 1], "game": 0}, -1, prob=p, replace=replace ) if not fused: assert len(subg.ntypes) == 3 @@ -674,13 +674,8 @@ def _test3(p, replace): assert subg["flips"].num_edges() == 0 for i in range(10): - subg = dgl.sampling.sample_neighbors( - hg, - {"user": [0, 1], "game": 0}, - 2, - prob=p, - replace=replace, - fused=fused, + subg = sample_neighbors_fusing_mode[fused]( + hg, {"user": [0, 1], "game": 0}, 2, prob=p, replace=replace ) if not fused: assert len(subg.ntypes) == 3 @@ -697,12 +692,11 @@ def _test3(p, replace): # test different fanouts for different relations for i in range(10): - subg = dgl.sampling.sample_neighbors( + subg = sample_neighbors_fusing_mode[fused]( hg, {"user": [0, 1], "game": 0, "coin": 0}, {"follow": 1, "play": 2, "liked-by": 0, "flips": -1}, replace=True, - fused=fused, ) if not fused: assert len(subg.ntypes) == 3 @@ -833,8 +827,8 @@ def _test_sample_neighbors_outedge(hypersparse, fused): g, hg = _gen_neighbor_sampling_test_graph(hypersparse, True) def _test1(p, replace): - subg = dgl.sampling.sample_neighbors( - g, [0, 1], -1, prob=p, replace=replace, edge_dir="out", fused=fused + subg = sample_neighbors_fusing_mode[fused]( + g, [0, 1], -1, prob=p, replace=replace, edge_dir="out" ) if not fused: assert subg.num_nodes() == g.num_nodes() @@ -854,14 +848,8 @@ def _test1(p, replace): assert uv == uv_ans for i in range(10): - subg = dgl.sampling.sample_neighbors( - g, - [0, 1], - 2, - prob=p, - replace=replace, - edge_dir="out", - fused=fused, + subg = sample_neighbors_fusing_mode[fused]( + g, [0, 1], 2, prob=p, replace=replace, edge_dir="out" ) if not fused: assert subg.num_nodes() == g.num_nodes() @@ -889,8 +877,8 @@ def _test1(p, replace): _test1("prob", False) # w/o replacement def _test2(p, replace): # fanout > #neighbors - subg = dgl.sampling.sample_neighbors( - g, [0, 2], -1, prob=p, replace=replace, edge_dir="out", fused=fused + subg = sample_neighbors_fusing_mode[fused]( + g, [0, 2], -1, prob=p, replace=replace, edge_dir="out" ) if not fused: assert subg.num_nodes() == g.num_nodes() @@ -909,14 +897,8 @@ def _test2(p, replace): # fanout > #neighbors assert uv == uv_ans for i in range(10): - subg = dgl.sampling.sample_neighbors( - g, - [0, 2], - 2, - prob=p, - replace=replace, - edge_dir="out", - fused=fused, + subg = sample_neighbors_fusing_mode[fused]( + g, [0, 2], 2, prob=p, replace=replace, edge_dir="out" ) if not fused: assert subg.num_nodes() == g.num_nodes() @@ -945,14 +927,13 @@ def _test2(p, replace): # fanout > #neighbors _test2("prob", False) # w/o replacement def _test3(p, replace): - subg = dgl.sampling.sample_neighbors( + subg = sample_neighbors_fusing_mode[fused]( hg, {"user": [0, 1], "game": 0}, -1, prob=p, replace=replace, edge_dir="out", - fused=fused, ) if not fused: @@ -966,14 +947,13 @@ def _test3(p, replace): assert subg["flips"].num_edges() == 0 for i in range(10): - subg = dgl.sampling.sample_neighbors( + subg = sample_neighbors_fusing_mode[fused]( hg, {"user": [0, 1], "game": 0}, 2, prob=p, replace=replace, edge_dir="out", - fused=fused, ) if not fused: assert len(subg.ntypes) == 3 @@ -1212,40 +1192,20 @@ def test_sample_neighbors_with_0deg(fused): ): pytest.skip("Fused sampling support CPU with backend PyTorch.") g = dgl.graph(([], []), num_nodes=5).to(F.ctx()) - sg = dgl.sampling.sample_neighbors( - g, - F.tensor([1, 2], dtype=F.int64), - 2, - edge_dir="in", - replace=False, - fused=fused, + sg = sample_neighbors_fusing_mode[fused]( + g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="in", replace=False ) assert sg.num_edges() == 0 - sg = dgl.sampling.sample_neighbors( - g, - F.tensor([1, 2], dtype=F.int64), - 2, - edge_dir="in", - replace=True, - fused=fused, + sg = sample_neighbors_fusing_mode[fused]( + g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="in", replace=True ) assert sg.num_edges() == 0 - sg = dgl.sampling.sample_neighbors( - g, - F.tensor([1, 2], dtype=F.int64), - 2, - edge_dir="out", - replace=False, - fused=fused, + sg = sample_neighbors_fusing_mode[fused]( + g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="out", replace=False ) assert sg.num_edges() == 0 - sg = dgl.sampling.sample_neighbors( - g, - F.tensor([1, 2], dtype=F.int64), - 2, - edge_dir="out", - replace=True, - fused=fused, + sg = sample_neighbors_fusing_mode[fused]( + g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="out", replace=True ) assert sg.num_edges() == 0 @@ -1672,7 +1632,7 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype, fused): ("drug", "treats", "disease"): excluded_d_t_d_edges, } - sg = dgl.sampling.sample_neighbors( + sg = sample_neighbors_fusing_mode[fused]( g, { "drug": sampled_drug_node, @@ -1681,7 +1641,6 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype, fused): }, sampled_amount, exclude_edges=excluded_edges, - fused=fused, ) if fused: @@ -1784,12 +1743,8 @@ def test_sample_neighbors_exclude_edges_homoG(dtype, fused): excluded_nodes_U = g_edges[U][b_idx:e_idx] excluded_nodes_V = g_edges[V][b_idx:e_idx] - sg = dgl.sampling.sample_neighbors( - g, - sampled_node, - sampled_amount, - exclude_edges=excluded_edges, - fused=fused, + sg = sample_neighbors_fusing_mode[fused]( + g, sampled_node, sampled_amount, exclude_edges=excluded_edges ) if fused: From 8ae7ab69457567fd97b1e6ab5c386ff2c12e0bd0 Mon Sep 17 00:00:00 2001 From: Adam Grabowski Date: Wed, 19 Jul 2023 09:18:45 +0000 Subject: [PATCH 10/10] Update benchmark fused sampling --- benchmarks/benchmarks/api/bench_fused_sample_neighbors.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmarks/api/bench_fused_sample_neighbors.py b/benchmarks/benchmarks/api/bench_fused_sample_neighbors.py index 91fc6705d5eb..f695a5c575ba 100644 --- a/benchmarks/benchmarks/api/bench_fused_sample_neighbors.py +++ b/benchmarks/benchmarks/api/bench_fused_sample_neighbors.py @@ -25,15 +25,15 @@ def track_time(graph_name, format, seed_nodes_num, fanout): # dry run for i in range(3): - dgl.sampling.sample_neighbors( - graph, seed_nodes, fanout, edge_dir=edge_dir, fused=True + dgl.sampling.sample_neighbors_fused( + graph, seed_nodes, fanout, edge_dir=edge_dir ) # timing with utils.Timer() as t: for i in range(50): - dgl.sampling.sample_neighbors( - graph, seed_nodes, fanout, edge_dir=edge_dir, fused=True + dgl.sampling.sample_neighbors_fused( + graph, seed_nodes, fanout, edge_dir=edge_dir ) return t.elapsed_secs / 50