Skip to content

Commit

Permalink
[GraphBolt] Improve DiskBasedFeature tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jul 16, 2024
1 parent 8368f62 commit 7ff5d77
Showing 1 changed file with 37 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import sys
import tempfile
import unittest
from functools import partial

import backend as F

import numpy as np
import pytest
Expand All @@ -17,6 +19,9 @@ def to_on_disk_numpy(test_dir, name, t):
return path


assert_equal = partial(torch.testing.assert_close, rtol=0, atol=0)


@unittest.skipIf(
not torch.ops.graphbolt.detect_io_uring(),
reason="DiskBasedFeature is not available on this system.",
Expand All @@ -25,46 +30,58 @@ def test_disk_based_feature():
with tempfile.TemporaryDirectory() as test_dir:
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]])
c = torch.randn([4111, 47])
metadata = {"max_value": 3}
path_a = to_on_disk_numpy(test_dir, "a", a)
path_b = to_on_disk_numpy(test_dir, "b", b)
path_c = to_on_disk_numpy(test_dir, "c", c)

feature_a = gb.DiskBasedFeature(path=path_a, metadata=metadata)
feature_b = gb.DiskBasedFeature(path=path_b)
feature_c = gb.DiskBasedFeature(path=path_c)

# Read the entire feature.
assert torch.equal(
feature_a.read(), torch.tensor([[1, 2, 3], [4, 5, 6]])
)
assert_equal(feature_a.read(), torch.tensor([[1, 2, 3], [4, 5, 6]]))

assert torch.equal(
assert_equal(
feature_b.read(), torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]])
)

# Test read the feature with ids.
assert torch.equal(
assert_equal(
feature_a.read(torch.tensor([0])),
torch.tensor([[1, 2, 3]]),
)
assert torch.equal(
assert_equal(
feature_b.read(torch.tensor([1])),
torch.tensor([[[4, 5], [6, 7]]]),
)

# test when the index tensor is large.
# Test reading into pin_memory
if F._default_context_str == "gpu":
res = feature_a.read(torch.tensor([0], pin_memory=True))
assert res.is_pinned()

# Test when the index tensor is large.
torch_based_feature_a = gb.TorchBasedFeature(a)
ind_a = torch.randint(low=0, high=2, size=(1, 4097))[0]
assert torch.equal(
ind_a = torch.randint(low=0, high=a.size(0), size=(4111,))
assert_equal(
feature_a.read(ind_a),
torch_based_feature_a.read(ind_a),
)
torch_based_feature_b = gb.TorchBasedFeature(b)
ind_b = torch.randint(low=0, high=2, size=(1, 4097))[0]
assert torch.equal(

# Test converting to torch_based_feature with read_into_memory()
torch_based_feature_b = feature_b.read_into_memory()
ind_b = torch.randint(low=0, high=b.size(0), size=(4111,))
assert_equal(
feature_b.read(ind_b),
torch_based_feature_b.read(ind_b),
)

# Test with larger stored feature tensor
ind_c = torch.randint(low=0, high=c.size(0), size=(4111,))
assert_equal(feature_c.read(ind_c), c[ind_c])

# Test get the size of the entire feature.
assert feature_a.size() == torch.Size([3])
assert feature_b.size() == torch.Size([2, 2])
Expand All @@ -88,8 +105,8 @@ def test_disk_based_feature():

# For windows, the file is locked by the numpy.load. We need to delete
# it before closing the temporary directory.
a = b = None
feature_a = feature_b = None
a = b = c = None
feature_a = feature_b = feature_c = None


@unittest.skipIf(
Expand All @@ -116,23 +133,21 @@ def test_disk_based_feature():
def test_more_disk_based_feature(dtype, idtype, shape, index):
if dtype == torch.complex128:
tensor = torch.complex(
torch.randint(0, 13, shape, dtype=torch.float64),
torch.randint(0, 13, shape, dtype=torch.float64),
torch.randint(0, 127, shape, dtype=torch.float64),
torch.randint(0, 127, shape, dtype=torch.float64),
)
else:
tensor = torch.randint(0, 13, shape, dtype=dtype)
tensor = torch.randint(0, 127, shape, dtype=dtype)
test_tensor = tensor.clone()
idx = torch.tensor(index)
idx = torch.tensor(index, dtype=idtype)

with tempfile.TemporaryDirectory() as test_dir:
path = to_on_disk_numpy(test_dir, "tensor", tensor)

feature = gb.DiskBasedFeature(path=path)

# Test read feature.
assert torch.equal(
feature.read(torch.tensor(idx, dtype=idtype)), test_tensor[idx]
)
assert_equal(feature.read(idx), test_tensor[idx.long()])


@unittest.skipIf(
Expand Down

0 comments on commit 7ff5d77

Please sign in to comment.