[GraphBolt] Async feature fetch refactor #7540

Merged: 35 commits from gb_async_feature_fetch into master on Jul 22, 2024.
Commits (35)
de5c5ea  [GraphBolt] Async feature fetch refactor (mfbalin, Jul 18, 2024)
cc7565b  add implementations for TorchBasedFeature and GPUCachedFeature. (mfbalin, Jul 18, 2024)
3255817  linting (mfbalin, Jul 18, 2024)
7fcc4b4  remove unused device fn. (mfbalin, Jul 18, 2024)
83099bc  add the type of the argument. (mfbalin, Jul 18, 2024)
79e4e2f  add to base class too. (mfbalin, Jul 18, 2024)
5b2bf55  add complete disk implementation (mfbalin, Jul 18, 2024)
1d49dc8  add the disk tests too from GPU. (mfbalin, Jul 18, 2024)
10ba72c  linting (mfbalin, Jul 18, 2024)
9f40899  linting (mfbalin, Jul 18, 2024)
537b815  fix the bug in the tests. (mfbalin, Jul 18, 2024)
2dc96d9  add comments for test skip condition. (mfbalin, Jul 18, 2024)
7eb3ee4  Merge branch 'master' into gb_async_feature_fetch (mfbalin, Jul 19, 2024)
573a031  add cpu cached feature read async and tests. (mfbalin, Jul 19, 2024)
108c418  make a separate test. (mfbalin, Jul 19, 2024)
080d69d  refine tests. (mfbalin, Jul 19, 2024)
5db3bba  add better docstring. (mfbalin, Jul 20, 2024)
17ef609  linting (mfbalin, Jul 20, 2024)
f06ec02  Merge branch 'master' into gb_async_feature_fetch (mfbalin, Jul 20, 2024)
06983c3  fix bugs and remove prints. (mfbalin, Jul 20, 2024)
b623645  one more bug. (mfbalin, Jul 20, 2024)
80bed85  Merge branch 'master' into gb_async_feature_fetch (mfbalin, Jul 20, 2024)
d805c24  use `scatter_async`. (mfbalin, Jul 20, 2024)
52cbd13  linting. (mfbalin, Jul 20, 2024)
5625d3a  Merge branch 'master' into gb_async_feature_fetch (mfbalin, Jul 20, 2024)
f3f5733  Update python/dgl/graphbolt/impl/gpu_cached_feature.py (mfbalin, Jul 20, 2024)
a62395e  Merge branch 'master' into gb_async_feature_fetch (mfbalin, Jul 20, 2024)
9161bf9  Merge branch 'master' into gb_async_feature_fetch (mfbalin, Jul 20, 2024)
05f9a86  Merge branch 'master' into gb_async_feature_fetch (mfbalin, Jul 20, 2024)
49283ef  Merge branch 'master' into gb_async_feature_fetch (mfbalin, Jul 20, 2024)
2c19798  Merge branch 'master' into gb_async_feature_fetch (mfbalin, Jul 20, 2024)
da4d20a  Merge branch 'master' into gb_async_feature_fetch (mfbalin, Jul 20, 2024)
42e5194  Merge branch 'master' into gb_async_feature_fetch (mfbalin, Jul 20, 2024)
dd636ed  Better docstring for `read_async` and `read_async_num_stages`. (mfbalin, Jul 22, 2024)
d9c73e9  Merge branch 'master' into gb_async_feature_fetch (mfbalin, Jul 22, 2024)
Changes from all commits
13 changes: 9 additions & 4 deletions python/dgl/graphbolt/feature_store.py
@@ -40,6 +40,7 @@ def read(self, ids: torch.Tensor = None):
 
     def read_async(self, ids: torch.Tensor):
         """Read the feature by index asynchronously.
+
         Parameters
         ----------
         ids : torch.Tensor
@@ -52,21 +53,25 @@ def read_async(self, ids: torch.Tensor):
             `read_async_num_stages(ids.device)`th invocation. The return result
             can be accessed by calling `.wait()`. on the returned future object.
             It is undefined behavior to call `.wait()` more than once.
+
         Example Usage
         --------
         >>> import dgl.graphbolt as gb
         >>> feature = gb.Feature(...)
         >>> ids = torch.tensor([0, 2])
-        >>> async_handle = feature.read_async(ids)
-        >>> for _ in range(feature.read_async_num_stages(ids.device)):
-        ...     future = next(async_handle)
+        >>> for stage, future in enumerate(feature.read_async(ids)):
+        ...     pass
+        >>> assert stage + 1 == feature.read_async_num_stages(ids.device)
         >>> result = future.wait()  # result contains the read values.
         """
         raise NotImplementedError
 
     def read_async_num_stages(self, ids_device: torch.device):
         """The number of stages of the read_async operation. See read_async
-        function for directions on its use.
+        function for directions on its use. This function is required to return
+        the number of yield operations when read_async is used with a tensor
+        residing on ids_device.
+
         Parameters
         ----------
         ids_device : torch.device
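The base-class docstring above pins down the contract: `read_async` is a generator that yields once per stage, the future yielded on the `read_async_num_stages(ids.device)`th invocation carries the result, and `.wait()` may be called only once. As a reference point, here is a minimal sketch of a conforming single-stage subclass; the `_InlineFuture` helper and `InMemoryFeature` class are illustrative stand-ins, not part of GraphBolt.

```python
import torch
import dgl.graphbolt as gb


class _InlineFuture:
    """Hypothetical helper: exposes an already-computed value via .wait()."""

    def __init__(self, value):
        self._value = value

    def wait(self):
        return self._value


class InMemoryFeature(gb.Feature):
    """Illustrative single-stage Feature; assumes gb.Feature is subclassable."""

    def __init__(self, tensor: torch.Tensor):
        super().__init__()
        self._tensor = tensor

    def read(self, ids: torch.Tensor = None):
        return self._tensor if ids is None else self._tensor[ids]

    def read_async(self, ids: torch.Tensor):
        # Exactly one yield, so read_async_num_stages must return 1; the
        # yielded future's .wait() hands back the read values.
        yield _InlineFuture(self._tensor[ids])

    def read_async_num_stages(self, ids_device: torch.device):
        return 1
```

Driving this with the docstring's loop ends with `stage == 0` and `future.wait()` equal to `self._tensor[ids]`.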
13 changes: 9 additions & 4 deletions python/dgl/graphbolt/impl/cpu_cached_feature.py
@@ -83,6 +83,7 @@ def read(self, ids: torch.Tensor = None):
 
     def read_async(self, ids: torch.Tensor):
         """Read the feature by index asynchronously.
+
         Parameters
         ----------
         ids : torch.Tensor
@@ -95,14 +96,15 @@ def read_async(self, ids: torch.Tensor):
             `read_async_num_stages(ids.device)`th invocation. The return result
             can be accessed by calling `.wait()`. on the returned future object.
             It is undefined behavior to call `.wait()` more than once.
+
         Example Usage
         --------
         >>> import dgl.graphbolt as gb
         >>> feature = gb.Feature(...)
         >>> ids = torch.tensor([0, 2])
-        >>> async_handle = feature.read_async(ids)
-        >>> for _ in range(feature.read_async_num_stages(ids.device)):
-        ...     future = next(async_handle)
+        >>> for stage, future in enumerate(feature.read_async(ids)):
+        ...     pass
+        >>> assert stage + 1 == feature.read_async_num_stages(ids.device)
         >>> result = future.wait()  # result contains the read values.
         """
         policy = self._feature._policy
@@ -309,7 +311,10 @@ def wait():
 
     def read_async_num_stages(self, ids_device: torch.device):
         """The number of stages of the read_async operation. See read_async
-        function for directions on its use.
+        function for directions on its use. This function is required to return
+        the number of yield operations when read_async is used with a tensor
+        residing on ids_device.
+
         Parameters
         ----------
         ids_device : torch.device
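Because each stage is yielded explicitly, a caller can interleave several outstanding reads instead of draining one generator at a time, which is where the multi-stage cached paths pay off. A hedged sketch of such a driver follows; the `fetch_all` helper and `batches` list are hypothetical, and only `read_async`, `read_async_num_stages`, and `.wait()` come from the API shown above.

```python
import torch


def fetch_all(feature, batches):
    """Overlap several async reads by advancing each generator one stage at a time."""
    handles = [feature.read_async(ids) for ids in batches]
    num_stages = [feature.read_async_num_stages(ids.device) for ids in batches]
    futures = [None] * len(batches)
    # Round-robin over the generators so the per-stage work overlaps.
    for stage in range(max(num_stages)):
        for i, handle in enumerate(handles):
            if stage < num_stages[i]:
                futures[i] = next(handle)
    # Only the future yielded at each generator's final stage carries the result.
    return [future.wait() for future in futures]


# Usage (hypothetical): results = fetch_all(feature, [ids0, ids1, ids2])
```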
13 changes: 9 additions & 4 deletions python/dgl/graphbolt/impl/gpu_cached_feature.py
@@ -90,6 +90,7 @@ def read(self, ids: torch.Tensor = None):
 
     def read_async(self, ids: torch.Tensor):
         """Read the feature by index asynchronously.
+
         Parameters
         ----------
         ids : torch.Tensor
@@ -102,14 +103,15 @@ def read_async(self, ids: torch.Tensor):
             `read_async_num_stages(ids.device)`th invocation. The return result
             can be accessed by calling `.wait()`. on the returned future object.
             It is undefined behavior to call `.wait()` more than once.
+
         Example Usage
         --------
         >>> import dgl.graphbolt as gb
         >>> feature = gb.Feature(...)
         >>> ids = torch.tensor([0, 2])
-        >>> async_handle = feature.read_async(ids)
-        >>> for _ in range(feature.read_async_num_stages(ids.device)):
-        ...     future = next(async_handle)
+        >>> for stage, future in enumerate(feature.read_async(ids)):
+        ...     pass
+        >>> assert stage + 1 == feature.read_async_num_stages(ids.device)
         >>> result = future.wait()  # result contains the read values.
         """
         values, missing_index, missing_keys = self._feature.query(ids)
@@ -136,7 +138,10 @@ def wait():
 
     def read_async_num_stages(self, ids_device: torch.device):
         """The number of stages of the read_async operation. See read_async
-        function for directions on its use.
+        function for directions on its use. This function is required to return
+        the number of yield operations when read_async is used with a tensor
+        residing on ids_device.
+
        Parameters
        ----------
        ids_device : torch.device
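The hunk above shows the shape of the cached read: `self._feature.query(ids)` splits the request into already-cached values, the positions in the output that missed, and the missing keys, and the miss path then fills those positions. A rough standalone sketch of that flow follows; the `cache` object, its `query`/`replace` methods, and the `fallback` tensor are stand-ins for illustration, not GraphBolt's actual internals.

```python
import torch


def read_through_cache(cache, fallback: torch.Tensor, ids: torch.Tensor):
    # Assumed query contract: rows already cached, the positions in the
    # output that missed, and the keys to fetch from the fallback store.
    values, missing_index, missing_keys = cache.query(ids)
    # Fetch the misses from the fallback store and move them to the cache device.
    missing_values = fallback[missing_keys.cpu()].to(values.device)
    values[missing_index] = missing_values  # scatter misses into place
    cache.replace(missing_keys, missing_values)  # hypothetical cache insert
    return values
```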
26 changes: 18 additions & 8 deletions python/dgl/graphbolt/impl/torch_based_feature_store.py
@@ -127,6 +127,7 @@ def read(self, ids: torch.Tensor = None):
 
     def read_async(self, ids: torch.Tensor):
         """Read the feature by index asynchronously.
+
         Parameters
         ----------
         ids : torch.Tensor
@@ -139,14 +140,15 @@ def read_async(self, ids: torch.Tensor):
             `read_async_num_stages(ids.device)`th invocation. The return result
             can be accessed by calling `.wait()`. on the returned future object.
             It is undefined behavior to call `.wait()` more than once.
+
         Example Usage
         --------
         >>> import dgl.graphbolt as gb
         >>> feature = gb.Feature(...)
         >>> ids = torch.tensor([0, 2])
-        >>> async_handle = feature.read_async(ids)
-        >>> for _ in range(feature.read_async_num_stages(ids.device)):
-        ...     future = next(async_handle)
+        >>> for stage, future in enumerate(feature.read_async(ids)):
+        ...     pass
+        >>> assert stage + 1 == feature.read_async_num_stages(ids.device)
         >>> result = future.wait()  # result contains the read values.
         """
         assert self._tensor.device.type == "cpu"
@@ -206,7 +208,10 @@ def wait():
 
     def read_async_num_stages(self, ids_device: torch.device):
         """The number of stages of the read_async operation. See read_async
-        function for directions on its use.
+        function for directions on its use. This function is required to return
+        the number of yield operations when read_async is used with a tensor
+        residing on ids_device.
+
         Parameters
         ----------
         ids_device : torch.device
@@ -405,6 +410,7 @@ def read(self, ids: torch.Tensor = None):
 
     def read_async(self, ids: torch.Tensor):
         """Read the feature by index asynchronously.
+
         Parameters
         ----------
         ids : torch.Tensor
@@ -417,14 +423,15 @@ def read_async(self, ids: torch.Tensor):
             `read_async_num_stages(ids.device)`th invocation. The return result
             can be accessed by calling `.wait()`. on the returned future object.
             It is undefined behavior to call `.wait()` more than once.
+
         Example Usage
         --------
         >>> import dgl.graphbolt as gb
         >>> feature = gb.Feature(...)
         >>> ids = torch.tensor([0, 2])
-        >>> async_handle = feature.read_async(ids)
-        >>> for _ in range(feature.read_async_num_stages(ids.device)):
-        ...     future = next(async_handle)
+        >>> for stage, future in enumerate(feature.read_async(ids)):
+        ...     pass
+        >>> assert stage + 1 == feature.read_async_num_stages(ids.device)
         >>> result = future.wait()  # result contains the read values.
         """
         assert torch.ops.graphbolt.detect_io_uring()
@@ -465,7 +472,10 @@ def wait():
 
     def read_async_num_stages(self, ids_device: torch.device):
         """The number of stages of the read_async operation. See read_async
-        function for directions on its use.
+        function for directions on its use. This function is required to return
+        the number of yield operations when read_async is used with a tensor
+        residing on ids_device.
+
         Parameters
         ----------
         ids_device : torch.device
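The disk-backed variant asserts `torch.ops.graphbolt.detect_io_uring()` before going asynchronous, so on kernels without io_uring support a caller may want to probe first and fall back to the synchronous `read`. A hedged sketch of such a guard follows; the `read_maybe_async` helper is hypothetical.

```python
import torch
import dgl.graphbolt  # noqa: F401 - assumed to register the torch.ops.graphbolt ops


def read_maybe_async(feature, ids: torch.Tensor):
    """Use the async path when io_uring is available, else read synchronously."""
    if torch.ops.graphbolt.detect_io_uring():
        future = None
        for future in feature.read_async(ids):
            pass  # only the last yielded future carries the result
        return future.wait()
    return feature.read(ids)  # synchronous fallback
```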