Skip to content

Commit

Permalink
add offset test.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Aug 20, 2024
1 parent 430f0f6 commit 0152e92
Showing 1 changed file with 35 additions and 19 deletions.
54 changes: 35 additions & 19 deletions tests/python/pytorch/graphbolt/impl/test_feature_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from dgl import graphbolt as gb


def _test_query_and_replace(policy1, policy2, keys):
offset = 111111
def _test_query_and_replace(policy1, policy2, keys, offset):
# Testing query_and_replace equivalence to query and then replace.
(
_,
Expand Down Expand Up @@ -60,7 +59,8 @@ def _test_query_and_replace(policy1, policy2, keys):
@pytest.mark.parametrize("feature_size", [2, 16])
@pytest.mark.parametrize("num_parts", [1, 2, None])
@pytest.mark.parametrize("policy", ["s3-fifo", "sieve", "lru", "clock"])
def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
@pytest.mark.parametrize("offset", [0, 1111111])
def test_feature_cache(offsets, dtype, feature_size, num_parts, policy, offset):
cache_size = 32 * (
torch.get_num_threads() if num_parts is None else num_parts
)
Expand All @@ -80,7 +80,9 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
reader_fn = lambda keys: a[keys]

keys = torch.tensor([0, 1])
values, missing_index, missing_keys, missing_offsets = cache.query(keys)
values, missing_index, missing_keys, missing_offsets = cache.query(
keys, offset
)
if not offsets:
missing_offsets = None
assert torch.equal(
Expand All @@ -89,17 +91,21 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
)

missing_values = a[missing_keys]
cache.replace(missing_keys, missing_values, missing_offsets)
cache.replace(missing_keys, missing_values, missing_offsets, offset)
values[missing_index] = missing_values
assert torch.equal(values, a[keys])
assert torch.equal(cache2.query_and_replace(keys, reader_fn), a[keys])
assert torch.equal(
cache2.query_and_replace(keys, reader_fn, offset), a[keys]
)

_test_query_and_replace(policy1, policy2, keys)
_test_query_and_replace(policy1, policy2, keys, offset)

pin_memory = F._default_context_str == "gpu"

keys = torch.arange(1, 33, pin_memory=pin_memory)
values, missing_index, missing_keys, missing_offsets = cache.query(keys)
values, missing_index, missing_keys, missing_offsets = cache.query(
keys, offset
)
if not offsets:
missing_offsets = None
assert torch.equal(
Expand All @@ -109,38 +115,48 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
assert not pin_memory or values.is_pinned()

missing_values = a[missing_keys]
cache.replace(missing_keys, missing_values, missing_offsets)
cache.replace(missing_keys, missing_values, missing_offsets, offset)
values[missing_index] = missing_values
assert torch.equal(values, a[keys])
assert torch.equal(cache2.query_and_replace(keys, reader_fn), a[keys])
assert torch.equal(
cache2.query_and_replace(keys, reader_fn, offset), a[keys]
)

_test_query_and_replace(policy1, policy2, keys)
_test_query_and_replace(policy1, policy2, keys, offset)

values, missing_index, missing_keys, missing_offsets = cache.query(keys)
values, missing_index, missing_keys, missing_offsets = cache.query(
keys, offset
)
if not offsets:
missing_offsets = None
assert torch.equal(missing_keys.flip([0]), torch.tensor([]))

missing_values = a[missing_keys]
cache.replace(missing_keys, missing_values, missing_offsets)
cache.replace(missing_keys, missing_values, missing_offsets, offset)
values[missing_index] = missing_values
assert torch.equal(values, a[keys])
assert torch.equal(cache2.query_and_replace(keys, reader_fn), a[keys])
assert torch.equal(
cache2.query_and_replace(keys, reader_fn, offset), a[keys]
)

_test_query_and_replace(policy1, policy2, keys)
_test_query_and_replace(policy1, policy2, keys, offset)

values, missing_index, missing_keys, missing_offsets = cache.query(keys)
values, missing_index, missing_keys, missing_offsets = cache.query(
keys, offset
)
if not offsets:
missing_offsets = None
assert torch.equal(missing_keys.flip([0]), torch.tensor([]))

missing_values = a[missing_keys]
cache.replace(missing_keys, missing_values, missing_offsets)
cache.replace(missing_keys, missing_values, missing_offsets, offset)
values[missing_index] = missing_values
assert torch.equal(values, a[keys])
assert torch.equal(cache2.query_and_replace(keys, reader_fn), a[keys])
assert torch.equal(
cache2.query_and_replace(keys, reader_fn, offset), a[keys]
)

_test_query_and_replace(policy1, policy2, keys)
_test_query_and_replace(policy1, policy2, keys, offset)

assert cache.miss_rate == cache2.miss_rate

Expand Down

0 comments on commit 0152e92

Please sign in to comment.