Skip to content

Commit

Permalink
[GraphBolt] Labor dependent template specialization. (#7220)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Mar 18, 2024
1 parent 74c5e31 commit a2c5472
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 44 deletions.
25 changes: 25 additions & 0 deletions graphbolt/include/graphbolt/continuous_seed.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,31 @@ class continuous_seed {
#endif // __CUDA_ARCH__
};

class single_seed {
uint64_t seed_;

public:
/* implicit */ single_seed(const int64_t seed) : seed_(seed) {} // NOLINT

single_seed(torch::Tensor seed_arr)
: seed_(seed_arr.data_ptr<int64_t>()[0]) {}

#ifdef __CUDACC__
__device__ inline float uniform(const uint64_t id) const {
const uint64_t kCurandSeed = 999961; // Could be any random number.
curandStatePhilox4_32_10_t rng;
curand_init(kCurandSeed, seed_, id, &rng);
return curand_uniform(&rng);
}
#else
inline float uniform(const uint64_t id) const {
pcg32 ng0(seed_, id);
std::uniform_real_distribution<float> uni;
return uni(ng0);
}
#endif // __CUDA_ARCH__
};

} // namespace graphbolt

#endif // GRAPHBOLT_CONTINUOUS_SEED_H_
31 changes: 21 additions & 10 deletions graphbolt/include/graphbolt/fused_csc_sampling_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
namespace graphbolt {
namespace sampling {

enum SamplerType { NEIGHBOR, LABOR };
enum SamplerType { NEIGHBOR, LABOR, LABOR_DEPENDENT };

constexpr bool is_labor(SamplerType S) {
return S == SamplerType::LABOR || S == SamplerType::LABOR_DEPENDENT;
}

template <SamplerType S>
struct SamplerArgs;
Expand All @@ -27,6 +31,13 @@ struct SamplerArgs<SamplerType::NEIGHBOR> {};

template <>
struct SamplerArgs<SamplerType::LABOR> {
const torch::Tensor& indices;
single_seed random_seed;
int64_t num_nodes;
};

template <>
struct SamplerArgs<SamplerType::LABOR_DEPENDENT> {
const torch::Tensor& indices;
continuous_seed random_seed;
int64_t num_nodes;
Expand Down Expand Up @@ -555,12 +566,12 @@ int64_t Pick(
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr);

template <typename PickedType>
int64_t Pick(
template <SamplerType S, typename PickedType>
std::enable_if_t<is_labor(S), int64_t> Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr);

template <typename PickedType>
int64_t TemporalPick(
Expand Down Expand Up @@ -619,13 +630,13 @@ int64_t TemporalPickByEtype(
PickedType* picked_data_ptr);

template <
bool NonUniform, bool Replace, typename ProbsType, typename PickedType,
int StackSize = 1024>
int64_t LaborPick(
bool NonUniform, bool Replace, typename ProbsType, SamplerType S,
typename PickedType, int StackSize = 1024>
std::enable_if_t<is_labor(S), int64_t> LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr);

} // namespace sampling
} // namespace graphbolt
Expand Down
80 changes: 46 additions & 34 deletions graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <limits>
#include <numeric>
#include <tuple>
#include <type_traits>
#include <vector>

#include "./macro.h"
Expand Down Expand Up @@ -660,26 +661,37 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}

if (layer) {
SamplerArgs<SamplerType::LABOR> args = [&] {
if (random_seed.has_value()) {
return SamplerArgs<SamplerType::LABOR>{
indices_,
{random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()};
} else {
return SamplerArgs<SamplerType::LABOR>{
indices_,
RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()),
NumNodes()};
}
}();
return SampleNeighborsImpl(
nodes.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask,
args));
if (random_seed.has_value() && random_seed->numel() >= 2) {
SamplerArgs<SamplerType::LABOR_DEPENDENT> args{
indices_,
{random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()};
return SampleNeighborsImpl(
nodes.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_,
probs_or_mask, args));
} else {
auto args = [&] {
if (random_seed.has_value() && random_seed->numel() == 1) {
return SamplerArgs<SamplerType::LABOR>{
indices_, random_seed.value(), NumNodes()};
} else {
return SamplerArgs<SamplerType::LABOR>{
indices_,
RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()),
NumNodes()};
}
}();
return SampleNeighborsImpl(
nodes.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_,
probs_or_mask, args));
}
} else {
SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl(
Expand Down Expand Up @@ -1297,7 +1309,7 @@ int64_t TemporalPick(
}
return picked_indices.numel();
}
if constexpr (S == SamplerType::LABOR) {
if constexpr (is_labor(S)) {
return Pick(
offset, num_neighbors, fanout, replace, options, masked_prob, args,
picked_data_ptr);
Expand Down Expand Up @@ -1383,12 +1395,12 @@ int64_t TemporalPickByEtype(
return pick_offset;
}

template <typename PickedType>
int64_t Pick(
template <SamplerType S, typename PickedType>
std::enable_if_t<is_labor(S), int64_t> Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr) {
if (fanout == 0) return 0;
if (probs_or_mask.has_value()) {
if (fanout < 0) {
Expand Down Expand Up @@ -1438,9 +1450,9 @@ inline T invcdf(T u, int64_t n, T rem) {
return rem * (one - std::pow(one - u, one / n));
}

template <typename T>
template <typename T, typename seed_t>
inline T jth_sorted_uniform_random(
continuous_seed seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) {
seed_t seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) {
const T u = seed.uniform(t + j * c);
// https://mathematica.stackexchange.com/a/256707
rem -= invcdf(u, n, rem);
Expand Down Expand Up @@ -1474,13 +1486,13 @@ inline T jth_sorted_uniform_random(
* should be put. Enough memory space should be allocated in advance.
*/
template <
bool NonUniform, bool Replace, typename ProbsType, typename PickedType,
int StackSize>
inline int64_t LaborPick(
bool NonUniform, bool Replace, typename ProbsType, SamplerType S,
typename PickedType, int StackSize>
inline std::enable_if_t<is_labor(S), int64_t> LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr) {
fanout = Replace ? fanout : std::min(fanout, num_neighbors);
if (!NonUniform && !Replace && fanout >= num_neighbors) {
std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
Expand All @@ -1504,8 +1516,8 @@ inline int64_t LaborPick(
}
AT_DISPATCH_INDEX_TYPES(
args.indices.scalar_type(), "LaborPickMain", ([&] {
const index_t* local_indices_data =
args.indices.data_ptr<index_t>() + offset;
const auto local_indices_data =
reinterpret_cast<index_t*>(args.indices.data_ptr()) + offset;
if constexpr (Replace) {
// [Algorithm] @mfbalin
// Use a max-heap to get rid of the big random numbers and filter the
Expand Down

0 comments on commit a2c5472

Please sign in to comment.