Skip to content

Commit

Permalink
Add alternative for non-pytorch runs
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-dlasalle committed May 16, 2023
1 parent 4116ebc commit 1bbf3ea
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
18 changes: 14 additions & 4 deletions src/array/cuda/rowwise_sampling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <curand_kernel.h>
#include <dgl/random.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/tensordispatch.h>

#include <numeric>

Expand All @@ -15,9 +16,11 @@
#include "./dgl_cub.cuh"
#include "./utils.h"

using namespace dgl::cuda;
using namespace dgl::aten::cuda;
using TensorDispatcher = dgl::runtime::TensorDispatcher;

namespace dgl {
using namespace cuda;
using namespace aten::cuda;
namespace aten {
namespace impl {

Expand Down Expand Up @@ -287,8 +290,15 @@ COOMatrix _CSRRowWiseSamplingUniform(
cudaEvent_t copyEvent;
CUDA_CALL(cudaEventCreate(&copyEvent));

auto new_len_tensor = NDArray::PinnedEmpty(
{1}, DGLDataTypeTraits<IdType>::dtype, DGLContext{kDGLCPU, 0});
NDArray new_len_tensor;
if (TensorDispatcher::Global()->IsAvailable()) {
new_len_tensor = NDArray::PinnedEmpty(
{1}, DGLDataTypeTraits<IdType>::dtype, DGLContext{kDGLCPU, 0});
} else {
// use pageable memory, it will unecessarily block but be functional
new_len_tensor = NDArray::Empty(
{1}, DGLDataTypeTraits<IdType>::dtype, DGLContext{kDGLCPU, 0});
}

// copy using the internal current stream
CUDA_CALL(cudaMemcpyAsync(
Expand Down
15 changes: 12 additions & 3 deletions src/graph/transform/cuda/cuda_to_block.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <cuda_runtime.h>
#include <dgl/immutable_graph.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/tensordispatch.h>

#include <algorithm>
#include <memory>
Expand All @@ -36,6 +37,7 @@
using namespace dgl::aten;
using namespace dgl::runtime::cuda;
using namespace dgl::transform::cuda;
using TensorDispatcher = dgl::runtime::TensorDispatcher;

namespace dgl {
namespace transform {
Expand Down Expand Up @@ -178,9 +180,16 @@ struct CUDAIdsMapper {
stream);

CUDA_CALL(cudaEventCreate(&copyEvent));
new_len_tensor = NDArray::PinnedEmpty(
{num_ntypes}, DGLDataTypeTraits<int64_t>::dtype,
DGLContext{kDGLCPU, 0});
if (TensorDispatcher::Global()->IsAvailable()) {
new_len_tensor = NDArray::PinnedEmpty(
{num_ntypes}, DGLDataTypeTraits<int64_t>::dtype,
DGLContext{kDGLCPU, 0});
} else {
// use pageable memory, it will unecessarily block but be functional
new_len_tensor = NDArray::Empty(
{num_ntypes}, DGLDataTypeTraits<int64_t>::dtype,
DGLContext{kDGLCPU, 0});
}
CUDA_CALL(cudaMemcpyAsync(
new_len_tensor->data, count_lhs_device,
sizeof(*num_nodes_per_type.data()) * num_ntypes,
Expand Down

0 comments on commit 1bbf3ea

Please sign in to comment.