diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py index 9cdf390f2df2..981c6efefcef 100644 --- a/tests/python/pytorch/graphbolt/test_dataloader.py +++ b/tests/python/pytorch/graphbolt/test_dataloader.py @@ -53,15 +53,15 @@ def test_DataLoader(): @pytest.mark.parametrize("enable_feature_fetch", [True, False]) @pytest.mark.parametrize("overlap_feature_fetch", [True, False]) @pytest.mark.parametrize("overlap_graph_fetch", [True, False]) -@pytest.mark.parametrize("num_cached_edges", [0, 1024]) -@pytest.mark.parametrize("threshold", [1, 3]) +@pytest.mark.parametrize("num_gpu_cached_edges", [0, 1024]) +@pytest.mark.parametrize("gpu_cache_threshold", [1, 3]) def test_gpu_sampling_DataLoader( sampler_name, enable_feature_fetch, overlap_feature_fetch, overlap_graph_fetch, - num_cached_edges, - threshold, + num_gpu_cached_edges, + gpu_cache_threshold, ): N = 40 B = 4 @@ -98,8 +98,8 @@ def test_gpu_sampling_DataLoader( datapipe, overlap_feature_fetch=overlap_feature_fetch, overlap_graph_fetch=overlap_graph_fetch, - num_cached_edges=num_cached_edges, - threshold=threshold, + num_gpu_cached_edges=num_gpu_cached_edges, + gpu_cache_threshold=gpu_cache_threshold, ) bufferer_awaiter_cnt = int(enable_feature_fetch and overlap_feature_fetch) if overlap_graph_fetch: