Skip to content

Commit

Permalink
[GraphBolt][CUDA] Inplace pin memory for Graph and TorchFeatureStore (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Jan 18, 2024
1 parent 053c822 commit c864c91
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 3 deletions.
37 changes: 36 additions & 1 deletion python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ def __init__(
):
super().__init__()
self._c_csc_graph = c_csc_graph
self._is_inplace_pinned = set()

def __del__(self):
# torch.Tensor.pin_memory() is not an inplace operation. To make it
# truly in-place, we need to use cudaHostRegister. Then, we need to use
# cudaHostUnregister to unpin the tensor in the destructor.
# https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
for tensor in self._is_inplace_pinned:
assert (
torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0
)

@property
def total_num_nodes(self) -> int:
Expand Down Expand Up @@ -974,9 +985,33 @@ def _pin(x):

def pin_memory_(self):
"""Copy `FusedCSCSamplingGraph` to the pinned memory in-place."""
# torch.Tensor.pin_memory() is not an inplace operation. To make it
# truly in-place, we need to use cudaHostRegister. Then, we need to use
# cudaHostUnregister to unpin the tensor in the destructor.
# https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
cudart = torch.cuda.cudart()

def _pin(x):
return x.pin_memory() if hasattr(x, "pin_memory") else x
if hasattr(x, "pin_memory_"):
x.pin_memory_()
elif (
isinstance(x, torch.Tensor)
and not x.is_pinned()
and x.device.type == "cpu"
):
assert (
x.is_contiguous()
), "Tensor pinning is only supported for contiguous tensors."
assert (
cudart.cudaHostRegister(
x.data_ptr(), x.numel() * x.element_size(), 0
)
== 0
)

self._is_inplace_pinned.add(x)

return x

self._apply_to_members(_pin)

Expand Down
31 changes: 29 additions & 2 deletions python/dgl/graphbolt/impl/torch_based_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@ def __init__(self, torch_feature: torch.Tensor, metadata: Dict = None):
# Make sure the tensor is contiguous.
self._tensor = torch_feature.contiguous()
self._metadata = metadata
self._is_inplace_pinned = set()

def __del__(self):
# torch.Tensor.pin_memory() is not an inplace operation. To make it
# truly in-place, we need to use cudaHostRegister. Then, we need to use
# cudaHostUnregister to unpin the tensor in the destructor.
# https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
for tensor in self._is_inplace_pinned:
assert (
torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0
)

def read(self, ids: torch.Tensor = None):
"""Read the feature by index.
Expand Down Expand Up @@ -169,14 +180,30 @@ def metadata(self):

def pin_memory_(self):
"""In-place operation to copy the feature to pinned memory."""
self._tensor = self._tensor.pin_memory()
# torch.Tensor.pin_memory() is not an inplace operation. To make it
# truly in-place, we need to use cudaHostRegister. Then, we need to use
# cudaHostUnregister to unpin the tensor in the destructor.
# https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
x = self._tensor
if not x.is_pinned() and x.device.type == "cpu":
assert (
x.is_contiguous()
), "Tensor pinning is only supported for contiguous tensors."
assert (
torch.cuda.cudart().cudaHostRegister(
x.data_ptr(), x.numel() * x.element_size(), 0
)
== 0
)

self._is_inplace_pinned.add(x)

def to(self, device): # pylint: disable=invalid-name
"""Copy `TorchBasedFeature` to the specified device."""
# copy.copy is a shallow copy so it does not copy tensor memory.
self2 = copy.copy(self)
if device == "pinned":
self2.pin_memory_()
self2._tensor = self2._tensor.pin_memory()
else:
self2._tensor = self2._tensor.to(device)
return self2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1601,10 +1601,14 @@ def test_csc_sampling_graph_to_device(device):
def test_csc_sampling_graph_to_pinned_memory():
# Construct FusedCSCSamplingGraph.
graph = create_fused_csc_sampling_graph()
ptr = graph.csc_indptr.data_ptr()

# Copy to pinned_memory in-place.
graph.pin_memory_()

# Check if pinning is truly in-place.
assert graph.csc_indptr.data_ptr() == ptr

is_graph_on_device_type(graph, "cpu")
is_graph_pinned(graph)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ def test_torch_based_pinned_feature(dtype, idtype, shape):
feature = gb.TorchBasedFeature(tensor)
feature.pin_memory_()

# Check if pinning is truly in-place.
assert feature._tensor.data_ptr() == tensor.data_ptr()

# Test read entire pinned feature, the result should be on cuda.
assert torch.equal(feature.read(), test_tensor_cuda)
assert feature.read().is_cuda
Expand Down

0 comments on commit c864c91

Please sign in to comment.