Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GraphBolt] Labor dependent template specialization. #7220

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions graphbolt/include/graphbolt/continuous_seed.h
Original file line number Diff line number Diff line change
@@ -92,6 +92,31 @@ class continuous_seed {
#endif // __CUDA_ARCH__
};

class single_seed {
uint64_t seed_;

public:
/* implicit */ single_seed(const int64_t seed) : seed_(seed) {} // NOLINT

single_seed(torch::Tensor seed_arr)
: seed_(seed_arr.data_ptr<int64_t>()[0]) {}

#ifdef __CUDACC__
__device__ inline float uniform(const uint64_t id) const {
const uint64_t kCurandSeed = 999961; // Could be any random number.
curandStatePhilox4_32_10_t rng;
curand_init(kCurandSeed, seed_, id, &rng);
return curand_uniform(&rng);
}
#else
inline float uniform(const uint64_t id) const {
pcg32 ng0(seed_, id);
std::uniform_real_distribution<float> uni;
return uni(ng0);
}
#endif // __CUDA_ARCH__
};

} // namespace graphbolt

#endif // GRAPHBOLT_CONTINUOUS_SEED_H_
31 changes: 21 additions & 10 deletions graphbolt/include/graphbolt/fused_csc_sampling_graph.h
Original file line number Diff line number Diff line change
@@ -17,7 +17,11 @@
namespace graphbolt {
namespace sampling {

enum SamplerType { NEIGHBOR, LABOR };
enum SamplerType { NEIGHBOR, LABOR, LABOR_DEPENDENT };

constexpr bool is_labor(SamplerType S) {
return S == SamplerType::LABOR || S == SamplerType::LABOR_DEPENDENT;
}

template <SamplerType S>
struct SamplerArgs;
@@ -27,6 +31,13 @@ struct SamplerArgs<SamplerType::NEIGHBOR> {};

template <>
struct SamplerArgs<SamplerType::LABOR> {
const torch::Tensor& indices;
single_seed random_seed;
int64_t num_nodes;
};

template <>
struct SamplerArgs<SamplerType::LABOR_DEPENDENT> {
const torch::Tensor& indices;
continuous_seed random_seed;
int64_t num_nodes;
@@ -555,12 +566,12 @@ int64_t Pick(
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr);

template <typename PickedType>
int64_t Pick(
template <SamplerType S, typename PickedType>
std::enable_if_t<is_labor(S), int64_t> Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr);

template <typename PickedType>
int64_t TemporalPick(
@@ -619,13 +630,13 @@ int64_t TemporalPickByEtype(
PickedType* picked_data_ptr);

template <
bool NonUniform, bool Replace, typename ProbsType, typename PickedType,
int StackSize = 1024>
int64_t LaborPick(
bool NonUniform, bool Replace, typename ProbsType, SamplerType S,
typename PickedType, int StackSize = 1024>
std::enable_if_t<is_labor(S), int64_t> LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr);

} // namespace sampling
} // namespace graphbolt
80 changes: 46 additions & 34 deletions graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
#include <limits>
#include <numeric>
#include <tuple>
#include <type_traits>
#include <vector>

#include "./macro.h"
@@ -660,26 +661,37 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}

if (layer) {
SamplerArgs<SamplerType::LABOR> args = [&] {
if (random_seed.has_value()) {
return SamplerArgs<SamplerType::LABOR>{
indices_,
{random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()};
} else {
return SamplerArgs<SamplerType::LABOR>{
indices_,
RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()),
NumNodes()};
}
}();
return SampleNeighborsImpl(
nodes.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask,
args));
if (random_seed.has_value() && random_seed->numel() >= 2) {
SamplerArgs<SamplerType::LABOR_DEPENDENT> args{
indices_,
{random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()};
return SampleNeighborsImpl(
nodes.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_,
probs_or_mask, args));
} else {
auto args = [&] {
if (random_seed.has_value() && random_seed->numel() == 1) {
return SamplerArgs<SamplerType::LABOR>{
indices_, random_seed.value(), NumNodes()};
} else {
return SamplerArgs<SamplerType::LABOR>{
indices_,
RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()),
NumNodes()};
}
}();
return SampleNeighborsImpl(
nodes.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_,
probs_or_mask, args));
}
} else {
SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl(
@@ -1297,7 +1309,7 @@ int64_t TemporalPick(
}
return picked_indices.numel();
}
if constexpr (S == SamplerType::LABOR) {
if constexpr (is_labor(S)) {
return Pick(
offset, num_neighbors, fanout, replace, options, masked_prob, args,
picked_data_ptr);
@@ -1383,12 +1395,12 @@ int64_t TemporalPickByEtype(
return pick_offset;
}

template <typename PickedType>
int64_t Pick(
template <SamplerType S, typename PickedType>
std::enable_if_t<is_labor(S), int64_t> Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr) {
if (fanout == 0) return 0;
if (probs_or_mask.has_value()) {
if (fanout < 0) {
@@ -1438,9 +1450,9 @@ inline T invcdf(T u, int64_t n, T rem) {
return rem * (one - std::pow(one - u, one / n));
}

template <typename T>
template <typename T, typename seed_t>
inline T jth_sorted_uniform_random(
continuous_seed seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) {
seed_t seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) {
const T u = seed.uniform(t + j * c);
// https://mathematica.stackexchange.com/a/256707
rem -= invcdf(u, n, rem);
@@ -1474,13 +1486,13 @@ inline T jth_sorted_uniform_random(
* should be put. Enough memory space should be allocated in advance.
*/
template <
bool NonUniform, bool Replace, typename ProbsType, typename PickedType,
int StackSize>
inline int64_t LaborPick(
bool NonUniform, bool Replace, typename ProbsType, SamplerType S,
typename PickedType, int StackSize>
inline std::enable_if_t<is_labor(S), int64_t> LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr) {
fanout = Replace ? fanout : std::min(fanout, num_neighbors);
if (!NonUniform && !Replace && fanout >= num_neighbors) {
std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
@@ -1504,8 +1516,8 @@ inline int64_t LaborPick(
}
AT_DISPATCH_INDEX_TYPES(
args.indices.scalar_type(), "LaborPickMain", ([&] {
const index_t* local_indices_data =
args.indices.data_ptr<index_t>() + offset;
const auto local_indices_data =
reinterpret_cast<index_t*>(args.indices.data_ptr()) + offset;
if constexpr (Replace) {
// [Algorithm] @mfbalin
// Use a max-heap to get rid of the big random numbers and filter the