From 71b06733c9f40b5172b2e452726fbadbc96b976b Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Sat, 2 Mar 2024 03:16:53 +0000 Subject: [PATCH 01/10] [GraphBolt] `torch.compile()` support for `gb.expand_indptr`. --- graphbolt/src/expand_indptr.cc | 13 +++++++ graphbolt/src/python_binding.cc | 6 +++- python/dgl/graphbolt/__init__.py | 39 +++++++++++---------- python/dgl/graphbolt/base.py | 9 +++++ tests/python/pytorch/graphbolt/test_base.py | 17 +++++++++ 5 files changed, 64 insertions(+), 20 deletions(-) diff --git a/graphbolt/src/expand_indptr.cc b/graphbolt/src/expand_indptr.cc index 09ade4af0ab3..9074e33f5104 100644 --- a/graphbolt/src/expand_indptr.cc +++ b/graphbolt/src/expand_indptr.cc @@ -5,6 +5,7 @@ * @brief ExpandIndptr operators. */ #include +#include #include "./macro.h" #include "./utils.h" @@ -29,5 +30,17 @@ torch::Tensor ExpandIndptr( indptr.diff(), 0, output_size); } +TORCH_LIBRARY_IMPL(graphbolt, CPU, m) { + m.impl("expand_indptr", &ExpandIndptr); +} + +TORCH_LIBRARY_IMPL(graphbolt, CUDA, m) { + m.impl("expand_indptr", &ExpandIndptrImpl); +} + +TORCH_LIBRARY_IMPL(graphbolt, Autograd, m) { + m.impl("expand_indptr", torch::autograd::autogradNotImplementedFallback()); +} + } // namespace ops } // namespace graphbolt diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc index 443e5c87ad12..9510d4cf0e6a 100644 --- a/graphbolt/src/python_binding.cc +++ b/graphbolt/src/python_binding.cc @@ -24,6 +24,7 @@ namespace graphbolt { namespace sampling { TORCH_LIBRARY(graphbolt, m) { + m.impl_abstract_pystub("graphbolt", "//dgl.graphbolt"); m.class_("FusedSampledSubgraph") .def(torch::init<>()) .def_readwrite("indptr", &FusedSampledSubgraph::indptr) @@ -88,7 +89,10 @@ TORCH_LIBRARY(graphbolt, m) { m.def("isin", &IsIn); m.def("index_select", &ops::IndexSelect); m.def("index_select_csc", &ops::IndexSelectCSC); - m.def("expand_indptr", &ops::ExpandIndptr); + m.def( + "expand_indptr(Tensor indptr, ScalarType dtype, Tensor? node_ids, int? 
" + "output_size) -> Tensor", + {at::Tag::pt2_compliant_tag}); m.def("set_seed", &RandomEngine::SetManualSeed); #ifdef GRAPHBOLT_USE_CUDA m.def("set_max_uva_threads", &cuda::set_max_uva_threads); diff --git a/python/dgl/graphbolt/__init__.py b/python/dgl/graphbolt/__init__.py index e9c09c975d6f..0f33fcd4f579 100644 --- a/python/dgl/graphbolt/__init__.py +++ b/python/dgl/graphbolt/__init__.py @@ -5,25 +5,6 @@ import torch from .._ffi import libinfo -from .base import * -from .minibatch import * -from .dataloader import * -from .dataset import * -from .feature_fetcher import * -from .feature_store import * -from .impl import * -from .itemset import * -from .item_sampler import * -from .minibatch_transformer import * -from .negative_sampler import * -from .sampled_subgraph import * -from .subgraph_sampler import * -from .internal import ( - compact_csc_format, - unique_and_compact, - unique_and_compact_csc_formats, -) -from .utils import add_reverse_edges, add_reverse_edges_2, exclude_seed_edges def load_graphbolt(): @@ -53,3 +34,23 @@ def load_graphbolt(): load_graphbolt() + +from .base import * +from .minibatch import * +from .dataloader import * +from .dataset import * +from .feature_fetcher import * +from .feature_store import * +from .impl import * +from .itemset import * +from .item_sampler import * +from .minibatch_transformer import * +from .negative_sampler import * +from .sampled_subgraph import * +from .subgraph_sampler import * +from .internal import ( + compact_csc_format, + unique_and_compact, + unique_and_compact_csc_formats, +) +from .utils import add_reverse_edges, add_reverse_edges_2, exclude_seed_edges diff --git a/python/dgl/graphbolt/base.py b/python/dgl/graphbolt/base.py index 01906a15ad42..bafa52bce72c 100644 --- a/python/dgl/graphbolt/base.py +++ b/python/dgl/graphbolt/base.py @@ -63,6 +63,15 @@ def isin(elements, test_elements): return torch.ops.graphbolt.isin(elements, test_elements) +@torch.library.impl_abstract("graphbolt::expand_indptr") +def expand_indptr_abstract(indptr, dtype, node_ids, output_size): + if output_size is None: + output_size = torch.library.get_ctx().new_dynamic_size() + if dtype is None: + dtype = node_ids.dtype + return indptr.new_empty(output_size, dtype=dtype) + + def expand_indptr(indptr, dtype=None, node_ids=None, output_size=None): """Converts a given indptr offset tensor to a COO format tensor. If node_ids is not given, it is assumed to be equal to diff --git a/tests/python/pytorch/graphbolt/test_base.py b/tests/python/pytorch/graphbolt/test_base.py index 5d7d6c477c33..87e83578624e 100644 --- a/tests/python/pytorch/graphbolt/test_base.py +++ b/tests/python/pytorch/graphbolt/test_base.py @@ -7,6 +7,7 @@ import dgl.graphbolt as gb import pytest import torch +from torch.testing._internal.optests import opcheck from . 
import gb_test_utils @@ -296,6 +297,22 @@ def test_expand_indptr(nodes, dtype): gb_result = gb.expand_indptr(indptr, dtype, nodes, indptr[-1].item()) assert torch.equal(torch_result, gb_result) + # Tests torch.compile compatibility + for output_size in [None, indptr[-1].item()]: + kwargs = {"node_ids": nodes, "output_size": output_size} + opcheck( + torch.ops.graphbolt.expand_indptr, + (indptr, dtype), + kwargs, + test_utils=[ + "test_schema", + "test_autograd_registration", + "test_faketensor", + "test_aot_dispatch_dynamic", + ], + raise_exception=True, + ) + def test_csc_format_base_representation(): csc_format_base = gb.CSCFormatBase( From 3001f423f9d511d3b9e06944da6af1637c48aaff Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Sat, 2 Mar 2024 03:21:15 +0000 Subject: [PATCH 02/10] linting 1. --- python/dgl/graphbolt/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/dgl/graphbolt/base.py b/python/dgl/graphbolt/base.py index bafa52bce72c..f8e84d1eb954 100644 --- a/python/dgl/graphbolt/base.py +++ b/python/dgl/graphbolt/base.py @@ -65,6 +65,7 @@ def isin(elements, test_elements): @torch.library.impl_abstract("graphbolt::expand_indptr") def expand_indptr_abstract(indptr, dtype, node_ids, output_size): + """Abstract implementation of expand_indptr for torch.compile() support.""" if output_size is None: output_size = torch.library.get_ctx().new_dynamic_size() if dtype is None: From abc7c6ce20b09beca190e1d4b5a042839cf37fe8 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Sat, 2 Mar 2024 03:24:44 +0000 Subject: [PATCH 03/10] linting 2 --- python/dgl/graphbolt/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/dgl/graphbolt/__init__.py b/python/dgl/graphbolt/__init__.py index 0f33fcd4f579..48bc2628700f 100644 --- a/python/dgl/graphbolt/__init__.py +++ b/python/dgl/graphbolt/__init__.py @@ -35,6 +35,7 @@ def load_graphbolt(): load_graphbolt() +# pylint: disable=wrong-import-position from .base import * from .minibatch import * from .dataloader import * From 3a55cb7b5bafb01a430d2c40803df1081ba4c11a Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Sun, 3 Mar 2024 00:52:26 +0000 Subject: [PATCH 04/10] fix the issue by using SymInt. --- graphbolt/src/python_binding.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc index 9510d4cf0e6a..f79d304aad9e 100644 --- a/graphbolt/src/python_binding.cc +++ b/graphbolt/src/python_binding.cc @@ -90,8 +90,8 @@ TORCH_LIBRARY(graphbolt, m) { m.def("index_select", &ops::IndexSelect); m.def("index_select_csc", &ops::IndexSelectCSC); m.def( - "expand_indptr(Tensor indptr, ScalarType dtype, Tensor? node_ids, int? " - "output_size) -> Tensor", + "expand_indptr(Tensor indptr, ScalarType dtype, Tensor? node_ids, " + "SymInt? 
output_size) -> Tensor", {at::Tag::pt2_compliant_tag}); m.def("set_seed", &RandomEngine::SetManualSeed); #ifdef GRAPHBOLT_USE_CUDA From a0a2b45485c5886573844b592e5a069e80e1d9e0 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Sun, 3 Mar 2024 18:44:57 +0000 Subject: [PATCH 05/10] check number of graph breaks --- tests/python/pytorch/graphbolt/test_base.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/python/pytorch/graphbolt/test_base.py b/tests/python/pytorch/graphbolt/test_base.py index 87e83578624e..f8aa549699cd 100644 --- a/tests/python/pytorch/graphbolt/test_base.py +++ b/tests/python/pytorch/graphbolt/test_base.py @@ -7,7 +7,6 @@ import dgl.graphbolt as gb import pytest import torch -from torch.testing._internal.optests import opcheck from . import gb_test_utils @@ -297,6 +296,9 @@ def test_expand_indptr(nodes, dtype): gb_result = gb.expand_indptr(indptr, dtype, nodes, indptr[-1].item()) assert torch.equal(torch_result, gb_result) + import torch._dynamo as dynamo + from torch.testing._internal.optests import opcheck + # Tests torch.compile compatibility for output_size in [None, indptr[-1].item()]: kwargs = {"node_ids": nodes, "output_size": output_size} @@ -313,6 +315,12 @@ def test_expand_indptr(nodes, dtype): raise_exception=True, ) + explanation = dynamo.explain(gb.expand_indptr)( + indptr, dtype, nodes, output_size + ) + expected_breaks = -1 if output_size is None else 0 + assert explanation.graph_break_count == expected_breaks + def test_csc_format_base_representation(): csc_format_base = gb.CSCFormatBase( From ede5b54ffd808056bb368952a22a4903b60cf582 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Mon, 4 Mar 2024 04:30:50 +0000 Subject: [PATCH 06/10] add ifdefs to enable compilation with older torch. --- graphbolt/src/python_binding.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc index f79d304aad9e..0a2603eb0ae9 100644 --- a/graphbolt/src/python_binding.cc +++ b/graphbolt/src/python_binding.cc @@ -24,7 +24,9 @@ namespace graphbolt { namespace sampling { TORCH_LIBRARY(graphbolt, m) { +#ifdef HAS_IMPL_ABSTRACT_PYSTUB m.impl_abstract_pystub("graphbolt", "//dgl.graphbolt"); +#endif m.class_("FusedSampledSubgraph") .def(torch::init<>()) .def_readwrite("indptr", &FusedSampledSubgraph::indptr) @@ -91,8 +93,12 @@ TORCH_LIBRARY(graphbolt, m) { m.def("index_select_csc", &ops::IndexSelectCSC); m.def( "expand_indptr(Tensor indptr, ScalarType dtype, Tensor? node_ids, " - "SymInt? output_size) -> Tensor", - {at::Tag::pt2_compliant_tag}); + "SymInt? output_size) -> Tensor" +#ifdef HAS_PT2_COMPLIANT_TAG + , + {at::Tag::pt2_compliant_tag} +#endif + ); m.def("set_seed", &RandomEngine::SetManualSeed); #ifdef GRAPHBOLT_USE_CUDA m.def("set_max_uva_threads", &cuda::set_max_uva_threads); From eb140fea087d72d32b1050b489cac11c9b30fd78 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Mon, 4 Mar 2024 05:01:29 +0000 Subject: [PATCH 07/10] guard cuda implementation with ifdef. 
--- graphbolt/src/expand_indptr.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/graphbolt/src/expand_indptr.cc b/graphbolt/src/expand_indptr.cc index 9074e33f5104..82b3e0649236 100644 --- a/graphbolt/src/expand_indptr.cc +++ b/graphbolt/src/expand_indptr.cc @@ -34,9 +34,11 @@ TORCH_LIBRARY_IMPL(graphbolt, CPU, m) { m.impl("expand_indptr", &ExpandIndptr); } +#ifdef GRAPHBOLT_USE_CUDA TORCH_LIBRARY_IMPL(graphbolt, CUDA, m) { m.impl("expand_indptr", &ExpandIndptrImpl); } +#endif TORCH_LIBRARY_IMPL(graphbolt, Autograd, m) { m.impl("expand_indptr", torch::autograd::autogradNotImplementedFallback()); From 90527559a4dbb1ec4dbfd7caebe1a27ab4b8da7c Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Mon, 4 Mar 2024 05:06:29 +0000 Subject: [PATCH 08/10] add guards to python code and test as well. --- python/dgl/graphbolt/base.py | 19 +++++---- tests/python/pytorch/graphbolt/test_base.py | 44 +++++++++++---------- 2 files changed, 34 insertions(+), 29 deletions(-) diff --git a/python/dgl/graphbolt/base.py b/python/dgl/graphbolt/base.py index f8e84d1eb954..e4209396a4e2 100644 --- a/python/dgl/graphbolt/base.py +++ b/python/dgl/graphbolt/base.py @@ -4,6 +4,7 @@ from dataclasses import dataclass import torch +from torch.torch_version import TorchVersion from torch.utils.data import functional_datapipe from torchdata.datapipes.iter import IterDataPipe @@ -63,14 +64,16 @@ def isin(elements, test_elements): return torch.ops.graphbolt.isin(elements, test_elements) -@torch.library.impl_abstract("graphbolt::expand_indptr") -def expand_indptr_abstract(indptr, dtype, node_ids, output_size): - """Abstract implementation of expand_indptr for torch.compile() support.""" - if output_size is None: - output_size = torch.library.get_ctx().new_dynamic_size() - if dtype is None: - dtype = node_ids.dtype - return indptr.new_empty(output_size, dtype=dtype) +if TorchVersion(torch.__version__) >= TorchVersion("2.2.0a0"): + + @torch.library.impl_abstract("graphbolt::expand_indptr") + def expand_indptr_abstract(indptr, dtype, node_ids, output_size): + """Abstract implementation of expand_indptr for torch.compile() support.""" + if output_size is None: + output_size = torch.library.get_ctx().new_dynamic_size() + if dtype is None: + dtype = node_ids.dtype + return indptr.new_empty(output_size, dtype=dtype) def expand_indptr(indptr, dtype=None, node_ids=None, output_size=None): diff --git a/tests/python/pytorch/graphbolt/test_base.py b/tests/python/pytorch/graphbolt/test_base.py index f8aa549699cd..df59601f3520 100644 --- a/tests/python/pytorch/graphbolt/test_base.py +++ b/tests/python/pytorch/graphbolt/test_base.py @@ -7,6 +7,7 @@ import dgl.graphbolt as gb import pytest import torch +from torch.torch_version import TorchVersion from . 
import gb_test_utils @@ -299,27 +300,28 @@ def test_expand_indptr(nodes, dtype): import torch._dynamo as dynamo from torch.testing._internal.optests import opcheck - # Tests torch.compile compatibility - for output_size in [None, indptr[-1].item()]: - kwargs = {"node_ids": nodes, "output_size": output_size} - opcheck( - torch.ops.graphbolt.expand_indptr, - (indptr, dtype), - kwargs, - test_utils=[ - "test_schema", - "test_autograd_registration", - "test_faketensor", - "test_aot_dispatch_dynamic", - ], - raise_exception=True, - ) - - explanation = dynamo.explain(gb.expand_indptr)( - indptr, dtype, nodes, output_size - ) - expected_breaks = -1 if output_size is None else 0 - assert explanation.graph_break_count == expected_breaks + if TorchVersion(torch.__version__) >= TorchVersion("2.2.0a0"): + # Tests torch.compile compatibility + for output_size in [None, indptr[-1].item()]: + kwargs = {"node_ids": nodes, "output_size": output_size} + opcheck( + torch.ops.graphbolt.expand_indptr, + (indptr, dtype), + kwargs, + test_utils=[ + "test_schema", + "test_autograd_registration", + "test_faketensor", + "test_aot_dispatch_dynamic", + ], + raise_exception=True, + ) + + explanation = dynamo.explain(gb.expand_indptr)( + indptr, dtype, nodes, output_size + ) + expected_breaks = -1 if output_size is None else 0 + assert explanation.graph_break_count == expected_breaks def test_csc_format_base_representation(): From 9a2350cba5f64175766a71bf33ac2d263d1c7ffb Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Mon, 4 Mar 2024 05:09:04 +0000 Subject: [PATCH 09/10] move imports under version check as well. --- tests/python/pytorch/graphbolt/test_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/pytorch/graphbolt/test_base.py b/tests/python/pytorch/graphbolt/test_base.py index df59601f3520..e83ca550786d 100644 --- a/tests/python/pytorch/graphbolt/test_base.py +++ b/tests/python/pytorch/graphbolt/test_base.py @@ -297,10 +297,10 @@ def test_expand_indptr(nodes, dtype): gb_result = gb.expand_indptr(indptr, dtype, nodes, indptr[-1].item()) assert torch.equal(torch_result, gb_result) - import torch._dynamo as dynamo - from torch.testing._internal.optests import opcheck - if TorchVersion(torch.__version__) >= TorchVersion("2.2.0a0"): + import torch._dynamo as dynamo + from torch.testing._internal.optests import opcheck + # Tests torch.compile compatibility for output_size in [None, indptr[-1].item()]: kwargs = {"node_ids": nodes, "output_size": output_size} From ceb3c9513eff923540043130c7e80697188b92d0 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Tue, 5 Mar 2024 02:01:04 +0000 Subject: [PATCH 10/10] fix bug. 
--- graphbolt/src/python_binding.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc index 0a2603eb0ae9..d3adc2d0bb72 100644 --- a/graphbolt/src/python_binding.cc +++ b/graphbolt/src/python_binding.cc @@ -24,9 +24,6 @@ namespace graphbolt { namespace sampling { TORCH_LIBRARY(graphbolt, m) { -#ifdef HAS_IMPL_ABSTRACT_PYSTUB - m.impl_abstract_pystub("graphbolt", "//dgl.graphbolt"); -#endif m.class_("FusedSampledSubgraph") .def(torch::init<>()) .def_readwrite("indptr", &FusedSampledSubgraph::indptr) @@ -91,6 +88,13 @@ TORCH_LIBRARY(graphbolt, m) { m.def("isin", &IsIn); m.def("index_select", &ops::IndexSelect); m.def("index_select_csc", &ops::IndexSelectCSC); + m.def("set_seed", &RandomEngine::SetManualSeed); +#ifdef GRAPHBOLT_USE_CUDA + m.def("set_max_uva_threads", &cuda::set_max_uva_threads); +#endif +#ifdef HAS_IMPL_ABSTRACT_PYSTUB + m.impl_abstract_pystub("dgl.graphbolt.base", "//dgl.graphbolt.base"); +#endif m.def( "expand_indptr(Tensor indptr, ScalarType dtype, Tensor? node_ids, " "SymInt? output_size) -> Tensor" @@ -99,10 +103,6 @@ TORCH_LIBRARY(graphbolt, m) { {at::Tag::pt2_compliant_tag} #endif ); - m.def("set_seed", &RandomEngine::SetManualSeed); -#ifdef GRAPHBOLT_USE_CUDA - m.def("set_max_uva_threads", &cuda::set_max_uva_threads); -#endif } } // namespace sampling
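
Usage sketch (not part of the patch series above; the helper name indptr_to_coo_rows and the sample tensors are illustrative assumptions, and it presumes a GraphBolt build containing these patches on PyTorch >= 2.2.0): with the PT2-compliant schema registered in python_binding.cc and the abstract/fake implementation in python/dgl/graphbolt/base.py, gb.expand_indptr can be traced by torch.compile instead of falling back to eager. Passing an eagerly computed output_size keeps the output shape static, which is the zero-graph-break path asserted in the test.

import torch

import dgl.graphbolt as gb


@torch.compile
def indptr_to_coo_rows(indptr, node_ids, output_size):
    # Semantically equivalent to node_ids.repeat_interleave(indptr.diff()),
    # i.e. it expands a CSC/CSR indptr into per-edge node ids.
    return gb.expand_indptr(
        indptr, dtype=node_ids.dtype, node_ids=node_ids, output_size=output_size
    )


indptr = torch.tensor([0, 2, 5, 7])
node_ids = torch.tensor([10, 20, 30])
# Compute output_size eagerly so the compiled region sees a plain int.
rows = indptr_to_coo_rows(indptr, node_ids, indptr[-1].item())
print(rows)  # expected: tensor([10, 10, 20, 20, 20, 30, 30])

Omitting output_size still works because the abstract implementation allocates a new dynamic size via torch.library.get_ctx().new_dynamic_size(), but the resulting data-dependent shape is the case the test treats separately from the zero-break path.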