diff --git a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py index b123773294d8..47976c834612 100644 --- a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py +++ b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py @@ -6,8 +6,7 @@ from dgl import graphbolt as gb -def _test_query_and_replace(policy1, policy2, keys): - offset = 111111 +def _test_query_and_replace(policy1, policy2, keys, offset): # Testing query_and_replace equivalence to query and then replace. ( _, @@ -60,7 +59,8 @@ def _test_query_and_replace(policy1, policy2, keys): @pytest.mark.parametrize("feature_size", [2, 16]) @pytest.mark.parametrize("num_parts", [1, 2, None]) @pytest.mark.parametrize("policy", ["s3-fifo", "sieve", "lru", "clock"]) -def test_feature_cache(offsets, dtype, feature_size, num_parts, policy): +@pytest.mark.parametrize("offset", [0, 1111111]) +def test_feature_cache(offsets, dtype, feature_size, num_parts, policy, offset): cache_size = 32 * ( torch.get_num_threads() if num_parts is None else num_parts ) @@ -80,7 +80,9 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy): reader_fn = lambda keys: a[keys] keys = torch.tensor([0, 1]) - values, missing_index, missing_keys, missing_offsets = cache.query(keys) + values, missing_index, missing_keys, missing_offsets = cache.query( + keys, offset + ) if not offsets: missing_offsets = None assert torch.equal( @@ -89,17 +91,21 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy): ) missing_values = a[missing_keys] - cache.replace(missing_keys, missing_values, missing_offsets) + cache.replace(missing_keys, missing_values, missing_offsets, offset) values[missing_index] = missing_values assert torch.equal(values, a[keys]) - assert torch.equal(cache2.query_and_replace(keys, reader_fn), a[keys]) + assert torch.equal( + cache2.query_and_replace(keys, reader_fn, offset), a[keys] + ) - _test_query_and_replace(policy1, policy2, keys) + _test_query_and_replace(policy1, policy2, keys, offset) pin_memory = F._default_context_str == "gpu" keys = torch.arange(1, 33, pin_memory=pin_memory) - values, missing_index, missing_keys, missing_offsets = cache.query(keys) + values, missing_index, missing_keys, missing_offsets = cache.query( + keys, offset + ) if not offsets: missing_offsets = None assert torch.equal( @@ -109,38 +115,48 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy): assert not pin_memory or values.is_pinned() missing_values = a[missing_keys] - cache.replace(missing_keys, missing_values, missing_offsets) + cache.replace(missing_keys, missing_values, missing_offsets, offset) values[missing_index] = missing_values assert torch.equal(values, a[keys]) - assert torch.equal(cache2.query_and_replace(keys, reader_fn), a[keys]) + assert torch.equal( + cache2.query_and_replace(keys, reader_fn, offset), a[keys] + ) - _test_query_and_replace(policy1, policy2, keys) + _test_query_and_replace(policy1, policy2, keys, offset) - values, missing_index, missing_keys, missing_offsets = cache.query(keys) + values, missing_index, missing_keys, missing_offsets = cache.query( + keys, offset + ) if not offsets: missing_offsets = None assert torch.equal(missing_keys.flip([0]), torch.tensor([])) missing_values = a[missing_keys] - cache.replace(missing_keys, missing_values, missing_offsets) + cache.replace(missing_keys, missing_values, missing_offsets, offset) values[missing_index] = missing_values assert torch.equal(values, a[keys]) - assert torch.equal(cache2.query_and_replace(keys, reader_fn), a[keys]) + assert torch.equal( + cache2.query_and_replace(keys, reader_fn, offset), a[keys] + ) - _test_query_and_replace(policy1, policy2, keys) + _test_query_and_replace(policy1, policy2, keys, offset) - values, missing_index, missing_keys, missing_offsets = cache.query(keys) + values, missing_index, missing_keys, missing_offsets = cache.query( + keys, offset + ) if not offsets: missing_offsets = None assert torch.equal(missing_keys.flip([0]), torch.tensor([])) missing_values = a[missing_keys] - cache.replace(missing_keys, missing_values, missing_offsets) + cache.replace(missing_keys, missing_values, missing_offsets, offset) values[missing_index] = missing_values assert torch.equal(values, a[keys]) - assert torch.equal(cache2.query_and_replace(keys, reader_fn), a[keys]) + assert torch.equal( + cache2.query_and_replace(keys, reader_fn, offset), a[keys] + ) - _test_query_and_replace(policy1, policy2, keys) + _test_query_and_replace(policy1, policy2, keys, offset) assert cache.miss_rate == cache2.miss_rate