[GraphBolt][CUDA] Add non_blocking option to CopyTo.
mfbalin committed Jul 27, 2024
1 parent b99db08 commit 550eb54
Showing 4 changed files with 105 additions and 19 deletions.
64 changes: 55 additions & 9 deletions python/dgl/graphbolt/base.py
@@ -20,7 +20,11 @@
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe

from .internal_utils import recursive_apply
from .internal_utils import (
get_nonproperty_attributes,
recursive_apply,
recursive_apply_reduce_all,
)

__all__ = [
"CANONICAL_ETYPE_DELIMITER",
@@ -306,10 +310,32 @@ def seed_type_str_to_ntypes(seed_type, seed_size):
return ntypes


def apply_to(x, device):
def apply_to(x, device, non_blocking=False):
"""Apply `to` function to object x only if it has `to`."""

return x.to(device) if hasattr(x, "to") else x
if device == "pinned" and hasattr(x, "pin_memory"):
return x.pin_memory()
if not hasattr(x, "to"):
return x
if not non_blocking:
return x.to(device)
# The copy is non-blocking only if the objects are pinned.
assert x.is_pinned(), f"{x} should be pinned."
return x.to(device, non_blocking=True)


def is_object_pinned(obj):
"""Recursively check all members of the object and return True if only if
all are pinned."""

for attr in get_nonproperty_attributes(obj):
member_result = recursive_apply_reduce_all(
getattr(obj, attr),
lambda x: x is None or x.is_pinned(),
)
if not member_result:
return False
return True
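
For reference, a minimal PyTorch-only sketch (illustrative, not part of this diff) of the pinned-memory requirement behind the `assert x.is_pinned()` in `apply_to`: only tensors placed in page-locked host memory via `pin_memory()` can be copied to the GPU asynchronously.

import torch

if torch.cuda.is_available():
    pageable = torch.arange(1024)
    pinned = pageable.pin_memory()  # page-locked copy of the same data
    assert pinned.is_pinned() and not pageable.is_pinned()
    # non_blocking=True is only meaningful from pinned memory; from pageable
    # memory the transfer silently falls back to a blocking path.
    on_gpu = pinned.to("cuda", non_blocking=True)
    torch.cuda.synchronize()  # wait for the asynchronous copy to finish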


@functional_datapipe("copy_to")
@@ -334,17 +360,22 @@ class CopyTo(IterDataPipe):
The DataPipe.
device : torch.device
The PyTorch CUDA device.
non_blocking : bool
Whether the copy should be performed without blocking. If enabled, all
elements must already be in pinned system memory. Default is False.
"""

def __init__(self, datapipe, device):
def __init__(self, datapipe, device, non_blocking=False):
super().__init__()
self.datapipe = datapipe
self.device = device
self.device = torch.device(device)
self.non_blocking = non_blocking

def __iter__(self):
for data in self.datapipe:
data = recursive_apply(data, apply_to, self.device)
yield data
yield recursive_apply(
data, apply_to, self.device, self.non_blocking
)
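
A usage sketch of the new option, adapted from the updated test at the end of this diff: pin each minibatch on the host first, then request the asynchronous copy via the functional form registered above.

import torch
import dgl.graphbolt as gb

item_sampler = gb.ItemSampler(gb.ItemSet(torch.arange(20), names="seeds"), 4)
# Keep each minibatch in pinned system memory so non_blocking=True is valid.
item_sampler = item_sampler.transform(lambda x: x.pin_memory())

dp = item_sampler.copy_to("cuda", non_blocking=True)
for data in dp:
    assert data.seeds.device.type == "cuda"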


@functional_datapipe("mark_end")
@@ -460,7 +491,9 @@ def __init__(self, indptr: torch.Tensor, indices: torch.Tensor):
def __repr__(self) -> str:
return _csc_format_base_str(self)

def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
def to(
self, device: torch.device, non_blocking=False
) -> None: # pylint: disable=invalid-name
"""Copy `CSCFormatBase` to the specified device using reflection."""

for attr in dir(self):
@@ -470,12 +503,25 @@ def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
self,
attr,
recursive_apply(
getattr(self, attr), lambda x: apply_to(x, device)
getattr(self, attr),
apply_to,
device,
non_blocking=non_blocking,
),
)

return self

def pin_memory(self):
"""Copy `SampledSubgraph` to the pinned memory using reflection."""

return self.to("pinned")

def is_pinned(self) -> bool:
"""Check whether `SampledSubgraph` is pinned using reflection."""

return is_object_pinned(self)
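
An illustrative sketch of the new reflection-based helpers on `CSCFormatBase` (editor's example with made-up indptr/indices values, assuming a CUDA-capable machine):

import torch
import dgl.graphbolt as gb

csc = gb.CSCFormatBase(
    indptr=torch.tensor([0, 1, 3]),
    indices=torch.tensor([0, 1, 2]),
)
csc = csc.pin_memory()  # same as csc.to("pinned"); pins indptr and indices
assert csc.is_pinned()  # True once every tensor member is pinned
csc = csc.to("cuda", non_blocking=True)  # asynchronous copy is now allowed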


def _csc_format_base_str(csc_format_base: CSCFormatBase) -> str:
final_str = "CSCFormatBase("
27 changes: 22 additions & 5 deletions python/dgl/graphbolt/minibatch.py
@@ -5,7 +5,13 @@

import torch

from .base import CSCFormatBase, etype_str_to_tuple, expand_indptr
from .base import (
apply_to,
CSCFormatBase,
etype_str_to_tuple,
expand_indptr,
is_object_pinned,
)
from .internal_utils import (
get_attributes,
get_nonproperty_attributes,
@@ -350,20 +356,31 @@ def to_pyg_data(self):
)
return pyg_data

def to(self, device: torch.device): # pylint: disable=invalid-name
def to(
self, device: torch.device, non_blocking=False
): # pylint: disable=invalid-name
"""Copy `MiniBatch` to the specified device using reflection."""

def _to(x):
return x.to(device) if hasattr(x, "to") else x
copy_fn = lambda x: apply_to(x, device, non_blocking=non_blocking)

transfer_attrs = get_nonproperty_attributes(self)

for attr in transfer_attrs:
# Only copy member variables.
setattr(self, attr, recursive_apply(getattr(self, attr), _to))
setattr(self, attr, recursive_apply(getattr(self, attr), copy_fn))

return self

def pin_memory(self):
"""Copy `MiniBatch` to the pinned memory using reflection."""

return self.to("pinned")

def is_pinned(self) -> bool:
"""Check whether `SampledSubgraph` is pinned using reflection."""

return is_object_pinned(self)
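
The same pattern applied to `MiniBatch`, as a hedged sketch; it assumes a `MiniBatch` carrying only a `seeds` field, matching the test in this change.

import torch
import dgl.graphbolt as gb

minibatch = gb.MiniBatch(seeds=torch.arange(4))
minibatch = minibatch.pin_memory()  # recursively pins every member tensor
assert minibatch.is_pinned()
minibatch = minibatch.to("cuda", non_blocking=True)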


def _minibatch_str(minibatch: MiniBatch) -> str:
final_str = ""
20 changes: 18 additions & 2 deletions python/dgl/graphbolt/sampled_subgraph.py
@@ -10,6 +10,7 @@
CSCFormatBase,
etype_str_to_tuple,
expand_indptr,
is_object_pinned,
isin,
)

@@ -232,7 +233,9 @@ def exclude_edges(
)
return calling_class(*_slice_subgraph(self, index))

def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
def to(
self, device: torch.device, non_blocking=False
) -> None: # pylint: disable=invalid-name
"""Copy `SampledSubgraph` to the specified device using reflection."""

for attr in dir(self):
@@ -242,12 +245,25 @@ def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
self,
attr,
recursive_apply(
getattr(self, attr), lambda x: apply_to(x, device)
getattr(self, attr),
apply_to,
device,
non_blocking=non_blocking,
),
)

return self

def pin_memory(self):
"""Copy `SampledSubgraph` to the pinned memory using reflection."""

return self.to("pinned")

def is_pinned(self) -> bool:
"""Check whether `SampledSubgraph` is pinned using reflection."""

return is_object_pinned(self)


def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
indptr = node_pair.indptr
13 changes: 10 additions & 3 deletions tests/python/pytorch/graphbolt/test_base.py
@@ -12,19 +12,26 @@
from . import gb_test_utils


@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo():
@unittest.skipIf(F._default_context_str != "gpu", "CopyTo needs GPU to test")
@pytest.mark.parametrize("non_blocking", [False, True])
def test_CopyTo(non_blocking):
item_sampler = gb.ItemSampler(
gb.ItemSet(torch.arange(20), names="seeds"), 4
)
if non_blocking:
item_sampler = item_sampler.transform(lambda x: x.pin_memory())

# Invoke CopyTo via class constructor.
dp = gb.CopyTo(item_sampler, "cuda")
for data in dp:
assert data.seeds.device.type == "cuda"

dp = gb.CopyTo(item_sampler, "cuda", non_blocking)
for data in dp:
assert data.seeds.device.type == "cuda"

# Invoke CopyTo via functional form.
dp = item_sampler.copy_to("cuda")
dp = item_sampler.copy_to("cuda", non_blocking)
for data in dp:
assert data.seeds.device.type == "cuda"

