Skip to content

Commit

Permalink
fix dataloader parameter names.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jun 26, 2024
1 parent 54a594a commit 71f4a6c
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tests/python/pytorch/graphbolt/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ def test_DataLoader():
@pytest.mark.parametrize("enable_feature_fetch", [True, False])
@pytest.mark.parametrize("overlap_feature_fetch", [True, False])
@pytest.mark.parametrize("overlap_graph_fetch", [True, False])
@pytest.mark.parametrize("num_cached_edges", [0, 1024])
@pytest.mark.parametrize("threshold", [1, 3])
@pytest.mark.parametrize("num_gpu_cached_edges", [0, 1024])
@pytest.mark.parametrize("gpu_cache_threshold", [1, 3])
def test_gpu_sampling_DataLoader(
sampler_name,
enable_feature_fetch,
overlap_feature_fetch,
overlap_graph_fetch,
num_cached_edges,
threshold,
num_gpu_cached_edges,
gpu_cache_threshold,
):
N = 40
B = 4
Expand Down Expand Up @@ -98,8 +98,8 @@ def test_gpu_sampling_DataLoader(
datapipe,
overlap_feature_fetch=overlap_feature_fetch,
overlap_graph_fetch=overlap_graph_fetch,
num_cached_edges=num_cached_edges,
threshold=threshold,
num_gpu_cached_edges=num_gpu_cached_edges,
gpu_cache_threshold=gpu_cache_threshold,
)
bufferer_awaiter_cnt = int(enable_feature_fetch and overlap_feature_fetch)
if overlap_graph_fetch:
Expand Down

0 comments on commit 71f4a6c

Please sign in to comment.