From 550eb5469157a0bfaf7b8bd24e7ce39309cfdcb3 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin
Date: Sat, 27 Jul 2024 19:42:15 -0400
Subject: [PATCH] [GraphBolt][CUDA] Add `non_blocking` option to `CopyTo`.

---
 python/dgl/graphbolt/base.py                | 64 ++++++++++++++++++---
 python/dgl/graphbolt/minibatch.py           | 27 +++++++--
 python/dgl/graphbolt/sampled_subgraph.py    | 20 ++++++-
 tests/python/pytorch/graphbolt/test_base.py | 13 ++++-
 4 files changed, 105 insertions(+), 19 deletions(-)

diff --git a/python/dgl/graphbolt/base.py b/python/dgl/graphbolt/base.py
index 1c27c8cff8c1..02532b6b0b73 100644
--- a/python/dgl/graphbolt/base.py
+++ b/python/dgl/graphbolt/base.py
@@ -20,7 +20,11 @@
 from torch.utils.data import functional_datapipe
 from torchdata.datapipes.iter import IterDataPipe

-from .internal_utils import recursive_apply
+from .internal_utils import (
+    get_nonproperty_attributes,
+    recursive_apply,
+    recursive_apply_reduce_all,
+)

 __all__ = [
     "CANONICAL_ETYPE_DELIMITER",
@@ -306,10 +310,32 @@ def seed_type_str_to_ntypes(seed_type, seed_size):
     return ntypes


-def apply_to(x, device):
+def apply_to(x, device, non_blocking=False):
     """Apply `to` function to object x only if it has `to`."""
-    return x.to(device) if hasattr(x, "to") else x
+    if device == "pinned" and hasattr(x, "pin_memory"):
+        return x.pin_memory()
+    if not hasattr(x, "to"):
+        return x
+    if not non_blocking:
+        return x.to(device)
+    # The copy is non-blocking only if the objects are pinned.
+    assert x.is_pinned(), f"{x} should be pinned."
+    return x.to(device, non_blocking=True)
+
+
+def is_object_pinned(obj):
+    """Recursively check all members of the object and return True if and
+    only if all of them are pinned."""
+
+    for attr in get_nonproperty_attributes(obj):
+        member_result = recursive_apply_reduce_all(
+            getattr(obj, attr),
+            lambda x: x is None or x.is_pinned(),
+        )
+        if not member_result:
+            return False
+    return True


 @functional_datapipe("copy_to")
 class CopyTo(IterDataPipe):
@@ -334,17 +360,22 @@ class CopyTo(IterDataPipe):
         The DataPipe.
     device : torch.device
         The PyTorch CUDA device.
+    non_blocking : bool
+        Whether the copy should be performed without blocking. If enabled, all
+        elements must already be in pinned system memory. Default is False.
""" - def __init__(self, datapipe, device): + def __init__(self, datapipe, device, non_blocking=False): super().__init__() self.datapipe = datapipe - self.device = device + self.device = torch.device(device) + self.non_blocking = non_blocking def __iter__(self): for data in self.datapipe: - data = recursive_apply(data, apply_to, self.device) - yield data + yield recursive_apply( + data, apply_to, self.device, self.non_blocking + ) @functional_datapipe("mark_end") @@ -460,7 +491,9 @@ def __init__(self, indptr: torch.Tensor, indices: torch.Tensor): def __repr__(self) -> str: return _csc_format_base_str(self) - def to(self, device: torch.device) -> None: # pylint: disable=invalid-name + def to( + self, device: torch.device, non_blocking=False + ) -> None: # pylint: disable=invalid-name """Copy `CSCFormatBase` to the specified device using reflection.""" for attr in dir(self): @@ -470,12 +503,25 @@ def to(self, device: torch.device) -> None: # pylint: disable=invalid-name self, attr, recursive_apply( - getattr(self, attr), lambda x: apply_to(x, device) + getattr(self, attr), + apply_to, + device, + non_blocking=non_blocking, ), ) return self + def pin_memory(self): + """Copy `SampledSubgraph` to the pinned memory using reflection.""" + + return self.to("pinned") + + def is_pinned(self) -> bool: + """Check whether `SampledSubgraph` is pinned using reflection.""" + + return is_object_pinned(self) + def _csc_format_base_str(csc_format_base: CSCFormatBase) -> str: final_str = "CSCFormatBase(" diff --git a/python/dgl/graphbolt/minibatch.py b/python/dgl/graphbolt/minibatch.py index bc2e62011ee7..7b34a8d5f1d3 100644 --- a/python/dgl/graphbolt/minibatch.py +++ b/python/dgl/graphbolt/minibatch.py @@ -5,7 +5,13 @@ import torch -from .base import CSCFormatBase, etype_str_to_tuple, expand_indptr +from .base import ( + apply_to, + CSCFormatBase, + etype_str_to_tuple, + expand_indptr, + is_object_pinned, +) from .internal_utils import ( get_attributes, get_nonproperty_attributes, @@ -350,20 +356,31 @@ def to_pyg_data(self): ) return pyg_data - def to(self, device: torch.device): # pylint: disable=invalid-name + def to( + self, device: torch.device, non_blocking=False + ): # pylint: disable=invalid-name """Copy `MiniBatch` to the specified device using reflection.""" - def _to(x): - return x.to(device) if hasattr(x, "to") else x + copy_fn = lambda x: apply_to(x, device, non_blocking=non_blocking) transfer_attrs = get_nonproperty_attributes(self) for attr in transfer_attrs: # Only copy member variables. 
-            setattr(self, attr, recursive_apply(getattr(self, attr), _to))
+            setattr(self, attr, recursive_apply(getattr(self, attr), copy_fn))

         return self

+    def pin_memory(self):
+        """Copy `MiniBatch` to pinned memory using reflection."""
+
+        return self.to("pinned")
+
+    def is_pinned(self) -> bool:
+        """Check whether `MiniBatch` is pinned using reflection."""
+
+        return is_object_pinned(self)
+

 def _minibatch_str(minibatch: MiniBatch) -> str:
     final_str = ""
diff --git a/python/dgl/graphbolt/sampled_subgraph.py b/python/dgl/graphbolt/sampled_subgraph.py
index 1e4f238e1367..d46535115170 100644
--- a/python/dgl/graphbolt/sampled_subgraph.py
+++ b/python/dgl/graphbolt/sampled_subgraph.py
@@ -10,6 +10,7 @@
     CSCFormatBase,
     etype_str_to_tuple,
     expand_indptr,
+    is_object_pinned,
     isin,
 )

@@ -232,7 +233,9 @@ def exclude_edges(
         )
         return calling_class(*_slice_subgraph(self, index))

-    def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
+    def to(
+        self, device: torch.device, non_blocking=False
+    ) -> None:  # pylint: disable=invalid-name
         """Copy `SampledSubgraph` to the specified device using reflection."""

         for attr in dir(self):
@@ -242,12 +245,25 @@ def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
                     self,
                     attr,
                     recursive_apply(
-                        getattr(self, attr), lambda x: apply_to(x, device)
+                        getattr(self, attr),
+                        apply_to,
+                        device,
+                        non_blocking=non_blocking,
                     ),
                 )

         return self

+    def pin_memory(self):
+        """Copy `SampledSubgraph` to pinned memory using reflection."""
+
+        return self.to("pinned")
+
+    def is_pinned(self) -> bool:
+        """Check whether `SampledSubgraph` is pinned using reflection."""
+
+        return is_object_pinned(self)
+

 def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
     indptr = node_pair.indptr
diff --git a/tests/python/pytorch/graphbolt/test_base.py b/tests/python/pytorch/graphbolt/test_base.py
index a65264369ed3..3ca506314a49 100644
--- a/tests/python/pytorch/graphbolt/test_base.py
+++ b/tests/python/pytorch/graphbolt/test_base.py
@@ -12,19 +12,26 @@
 from . import gb_test_utils


-@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
-def test_CopyTo():
+@unittest.skipIf(F._default_context_str != "gpu", "CopyTo needs GPU to test")
+@pytest.mark.parametrize("non_blocking", [False, True])
+def test_CopyTo(non_blocking):
     item_sampler = gb.ItemSampler(
         gb.ItemSet(torch.arange(20), names="seeds"), 4
     )
+    if non_blocking:
+        item_sampler = item_sampler.transform(lambda x: x.pin_memory())

     # Invoke CopyTo via class constructor.
     dp = gb.CopyTo(item_sampler, "cuda")
     for data in dp:
         assert data.seeds.device.type == "cuda"

+    dp = gb.CopyTo(item_sampler, "cuda", non_blocking)
+    for data in dp:
+        assert data.seeds.device.type == "cuda"
+
     # Invoke CopyTo via functional form.
-    dp = item_sampler.copy_to("cuda")
+    dp = item_sampler.copy_to("cuda", non_blocking)
     for data in dp:
         assert data.seeds.device.type == "cuda"
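
Usage note (not part of the diff): the sketch below mirrors the updated test and shows how the new flag is meant to be used end to end. The `import dgl.graphbolt as gb` alias and the toy `ItemSet` are illustrative assumptions, not code from this patch.

# Minimal sketch, assuming a CUDA-capable build of DGL with GraphBolt enabled.
import torch
import dgl.graphbolt as gb

item_sampler = gb.ItemSampler(gb.ItemSet(torch.arange(20), names="seeds"), 4)
# Stage each minibatch in pinned (page-locked) host memory so the later
# host-to-device transfer can be issued without blocking.
pinned = item_sampler.transform(lambda x: x.pin_memory())
# Functional form of CopyTo; non_blocking=True requires pinned inputs,
# otherwise apply_to raises an assertion error.
datapipe = pinned.copy_to("cuda", non_blocking=True)
for minibatch in datapipe:
    assert minibatch.seeds.device.type == "cuda"

With non_blocking left at its default of False, the previous synchronous behaviour is preserved, so existing copy_to call sites are unaffected.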