Tutorial issue #28

Closed
gui-li opened this issue Jun 11, 2021 · 2 comments
Labels: xgraph (Interpretability of Graph Neural Networks)


gui-li commented Jun 11, 2021

The following error occurs when I run your tutorial.
The failing code block:

# --- Create data collector and explanation processor ---
from dig.xgraph.evaluation import XCollector, ExplanationProcessor
x_collector = XCollector()

index = -1
node_indices = torch.where(dataset[0].test_mask * dataset[0].y != 0)[0].tolist()
data = dataset[0]

from dig.xgraph.method.subgraphx import PlotUtils
from dig.xgraph.method.subgraphx import find_closest_node_result, k_hop_subgraph_with_default_whole_graph
plotutils = PlotUtils(dataset_name='ba_shapes')

# Visualization
max_nodes = 5
node_idx = node_indices[6]
print(f'explain graph node {node_idx}')
data.to(device)
logits = model(data.x, data.edge_index)
prediction = logits[node_idx].argmax(-1).item()

_, explanation_results, related_preds = \
    explainer(data.x, data.edge_index, node_idx=node_idx, max_nodes=max_nodes)
result = find_closest_node_result(explanation_results[prediction], max_nodes=max_nodes)

plotutils = PlotUtils(dataset_name='ba_shapes')
explainer.visualization(explanation_results,
                        prediction,
                        max_nodes=max_nodes,
                        plot_utils=plotutils,
                        y=data.y)

The error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-6-7e5867836373> in <module>
     20 
     21 _, explanation_results, related_preds = \
---> 22     explainer(data.x, data.edge_index, node_idx=node_idx, max_nodes=max_nodes)
     23 result = find_closest_node_result(explanation_results[prediction], max_nodes=max_nodes)
     24 

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in __call__(self, x, edge_index, **kwargs)
    671                 payoff_func = self.get_reward_func(value_func, node_idx=self.mcts_state_map.node_idx)
    672                 self.mcts_state_map.set_score_func(payoff_func)
--> 673                 results = self.mcts_state_map.mcts(verbose=False)
    674 
    675                 tree_node_x = find_closest_node_result(results, max_nodes=max_nodes)

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in mcts(self, verbose)
    465             print(f"The nodes in graph is {self.graph.number_of_nodes()}")
    466         for rollout_idx in range(self.n_rollout):
--> 467             self.mcts_rollout(self.root)
    468             if verbose:
    469                 print(f"At the {rollout_idx} rollout, {len(self.state_map)} states that have been explored.")

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in mcts_rollout(self, tree_node)
    450                     tree_node.children.append(new_node)
    451 
--> 452             scores = compute_scores(self.score_func, tree_node.children)
    453             for child, score in zip(tree_node.children, scores):
    454                 child.P = score

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in compute_scores(score_func, children)
    163     for child in children:
    164         if child.P == 0:
--> 165             score = score_func(child.coalition, child.data)
    166         else:
    167             score = child.P

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/shapley.py in mc_l_shapley(coalition, data, local_raduis, value_func, subgraph_building_method, sample_num)
    216     include_mask = np.stack(set_include_masks, axis=0)
    217     marginal_contributions = \
--> 218         marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func)
    219 
    220     mc_l_shapley_value = (marginal_contributions).mean().item()

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/shapley.py in marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func)
     73     marginal_contribution_list = []
     74 
---> 75     for exclude_data, include_data in dataloader:
     76         exclude_values = value_func(exclude_data)
     77         include_values = value_func(include_data)

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/dataloader.py in __next__(self)
    515             if self._sampler_iter is None:
    516                 self._reset()
--> 517             data = self._next_data()
    518             self._num_yielded += 1
    519             if self._dataset_kind == _DatasetKind.Iterable and \

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    557         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    558         if self._pin_memory:
--> 559             data = _utils.pin_memory.pin_memory(data)
    560         return data
    561 

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py in pin_memory(data)
     53         return type(data)(*(pin_memory(sample) for sample in data))
     54     elif isinstance(data, container_abcs.Sequence):
---> 55         return [pin_memory(sample) for sample in data]
     56     elif hasattr(data, "pin_memory"):
     57         return data.pin_memory()

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py in <listcomp>(.0)
     53         return type(data)(*(pin_memory(sample) for sample in data))
     54     elif isinstance(data, container_abcs.Sequence):
---> 55         return [pin_memory(sample) for sample in data]
     56     elif hasattr(data, "pin_memory"):
     57         return data.pin_memory()

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py in pin_memory(data)
     55         return [pin_memory(sample) for sample in data]
     56     elif hasattr(data, "pin_memory"):
---> 57         return data.pin_memory()
     58     else:
     59         return data

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/data/data.py in pin_memory(self, *keys)
    363         If :obj:`*keys` is not given, the conversion is applied to all present
    364         attributes."""
--> 365         return self.apply(lambda x: x.pin_memory(), *keys)
    366 
    367     def debug(self):

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/data/data.py in apply(self, func, *keys)
    324         """
    325         for key, item in self(*keys):
--> 326             self[key] = self.__apply__(item, func)
    327         return self
    328 

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/data/data.py in __apply__(self, item, func)
    303     def __apply__(self, item, func):
    304         if torch.is_tensor(item):
--> 305             return func(item)
    306         elif isinstance(item, SparseTensor):
    307             # Not all apply methods are supported for `SparseTensor`, e.g.,

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/data/data.py in <lambda>(x)
    363         If :obj:`*keys` is not given, the conversion is applied to all present
    364         attributes."""
--> 365         return self.apply(lambda x: x.pin_memory(), *keys)
    366 
    367     def debug(self):

RuntimeError: cannot pin 'torch.cuda.LongTensor' only dense CPU tensors can be pinned

My installed packages:

# packages in environment at /home/*/anaconda3/envs/dig:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main
_openmp_mutex             4.5                       1_gnu
anyio                     3.1.0            py38h578d9bd_0    conda-forge
argon2-cffi               20.1.0           py38h497a2fe_2    conda-forge
ase                       3.21.1                   pypi_0    pypi
async_generator           1.10                       py_0    conda-forge
attrs                     21.2.0             pyhd8ed1ab_0    conda-forge
babel                     2.9.1              pyh44b312d_0    conda-forge
backcall                  0.2.0              pyh9f0ad1d_0    conda-forge
backports                 1.0                        py_2    conda-forge
backports.functools_lru_cache 1.6.4              pyhd8ed1ab_0    conda-forge
blas                      1.0                         mkl
bleach                    3.3.0              pyh44b312d_0    conda-forge
boost                     1.74.0           py38hc10631b_3    conda-forge
boost-cpp                 1.74.0               hc6e9bd1_3    conda-forge
brotlipy                  0.7.0           py38h497a2fe_1001    conda-forge
bzip2                     1.0.8                h7b6447c_0
ca-certificates           2021.5.30            ha878542_0    conda-forge
cairo                     1.16.0            h6cf1ce9_1008    conda-forge
captum                    0.2.0                    pypi_0    pypi
certifi                   2021.5.30        py38h578d9bd_0    conda-forge
cffi                      1.14.5           py38ha65f79e_0    conda-forge
chardet                   4.0.0            py38h578d9bd_1    conda-forge
cilog                     1.2.0                    pypi_0    pypi
cloudpickle               1.6.0                    pypi_0    pypi
cryptography              3.4.7            py38ha5dfef3_0    conda-forge
cudatoolkit               10.1.243             h6bb024c_0
cycler                    0.10.0                     py_2    conda-forge
decorator                 4.4.2                    pypi_0    pypi
defusedxml                0.7.1              pyhd8ed1ab_0    conda-forge
dive-into-graphs          0.0.4                    pypi_0    pypi
entrypoints               0.3             pyhd8ed1ab_1003    conda-forge
et-xmlfile                1.1.0                    pypi_0    pypi
ffmpeg                    4.3                  hf484d3e_0    pytorch
fontconfig                2.13.1            hba837de_1005    conda-forge
freetype                  2.10.4               h5ab3b9f_0
gettext                   0.19.8.1          h0b5b191_1005    conda-forge
gmp                       6.2.1                h2531618_2
gnutls                    3.6.15               he1e5248_0
googledrivedownloader     0.4                      pypi_0    pypi
greenlet                  1.1.0            py38h709712a_0    conda-forge
h5py                      3.2.1                    pypi_0    pypi
icu                       68.1                 h58526e2_0    conda-forge
idna                      2.10               pyh9f0ad1d_0    conda-forge
importlib-metadata        4.5.0            py38h578d9bd_0    conda-forge
intel-openmp              2021.2.0           h06a4308_610
ipykernel                 5.5.5            py38hd0cf306_0    conda-forge
ipython                   7.24.1           py38hd0cf306_0    conda-forge
ipython_genutils          0.2.0                      py_1    conda-forge
isodate                   0.6.0                    pypi_0    pypi
jedi                      0.18.0           py38h578d9bd_2    conda-forge
jinja2                    3.0.1              pyhd8ed1ab_0    conda-forge
joblib                    1.0.1                    pypi_0    pypi
jpeg                      9b                   h024ee3a_2
json5                     0.9.5              pyh9f0ad1d_0    conda-forge
jsonschema                3.2.0              pyhd8ed1ab_3    conda-forge
jupyter_client            6.1.12             pyhd8ed1ab_0    conda-forge
jupyter_core              4.7.1            py38h578d9bd_0    conda-forge
jupyter_server            1.8.0              pyhd8ed1ab_0    conda-forge
jupyterlab                3.0.16             pyhd8ed1ab_0    conda-forge
jupyterlab_pygments       0.1.2              pyh9f0ad1d_0    conda-forge
jupyterlab_server         2.6.0              pyhd8ed1ab_0    conda-forge
kiwisolver                1.3.1            py38h1fd1430_1    conda-forge
lame                      3.100                h7b6447c_0
lcms2                     2.12                 h3be6417_0
ld_impl_linux-64          2.35.1               h7274673_9
libffi                    3.3                  he6710b0_2
libgcc-ng                 9.3.0               h5101ec6_17
libglib                   2.68.3               h3e27bee_0    conda-forge
libgomp                   9.3.0               h5101ec6_17
libiconv                  1.16                 h516909a_0    conda-forge
libidn2                   2.3.1                h27cfd23_0
libpng                    1.6.37               hbc83047_0
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libstdcxx-ng              9.3.0               hd4cf53a_17
libtasn1                  4.16.0               h27cfd23_0
libtiff                   4.2.0                h85742a9_0
libunistring              0.9.10               h27cfd23_0
libuuid                   2.32.1            h7f98852_1000    conda-forge
libuv                     1.40.0               h7b6447c_0
libwebp-base              1.2.0                h27cfd23_0
libxcb                    1.13              h7f98852_1003    conda-forge
libxml2                   2.9.12               h72842e0_0    conda-forge
llvmlite                  0.36.0                   pypi_0    pypi
lz4-c                     1.9.3                h2531618_0
markupsafe                2.0.1            py38h497a2fe_0    conda-forge
matplotlib-base           3.4.2            py38hcc49a3a_0    conda-forge
matplotlib-inline         0.1.2              pyhd8ed1ab_2    conda-forge
mistune                   0.8.4           py38h497a2fe_1003    conda-forge
mkl                       2021.2.0           h06a4308_296
mkl-service               2.3.0            py38h27cfd23_1
mkl_fft                   1.3.0            py38h42c9631_2
mkl_random                1.2.1            py38ha9443f7_2
mypy-extensions           0.4.3                    pypi_0    pypi
nbclassic                 0.3.1              pyhd8ed1ab_1    conda-forge
nbclient                  0.5.3              pyhd8ed1ab_0    conda-forge
nbconvert                 6.0.7            py38h578d9bd_3    conda-forge
nbformat                  5.1.3              pyhd8ed1ab_0    conda-forge
ncurses                   6.2                  he6710b0_1
nest-asyncio              1.5.1              pyhd8ed1ab_0    conda-forge
nettle                    3.7.3                hbbd107a_1
networkx                  2.5.1                    pypi_0    pypi
ninja                     1.10.2               hff7bd54_1
notebook                  6.4.0              pyha770c72_0    conda-forge
numba                     0.53.1                   pypi_0    pypi
numpy                     1.20.2           py38h2d18471_0
numpy-base                1.20.2           py38hfae3a4d_0
olefile                   0.46                       py_0
openh264                  2.1.0                hd408876_0
openpyxl                  3.0.7                    pypi_0    pypi
openssl                   1.1.1k               h7f98852_0    conda-forge
packaging                 20.9               pyh44b312d_0    conda-forge
pandas                    1.2.4                    pypi_0    pypi
pandoc                    2.14.0.1             h7f98852_0    conda-forge
pandocfilters             1.4.2                      py_1    conda-forge
parso                     0.8.2              pyhd8ed1ab_0    conda-forge
pcre                      8.44                 he1b5a44_0    conda-forge
pexpect                   4.8.0              pyh9f0ad1d_2    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    8.2.0            py38he98fc37_0
pip                       21.1.2           py38h06a4308_0
pixman                    0.40.0               h36c2ea0_0    conda-forge
prometheus_client         0.11.0             pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.18             pyha770c72_0    conda-forge
pthread-stubs             0.4               h36c2ea0_1001    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pycairo                   1.20.1           py38hf61ee4a_0    conda-forge
pycparser                 2.20               pyh9f0ad1d_2    conda-forge
pygments                  2.9.0              pyhd8ed1ab_0    conda-forge
pyopenssl                 20.0.1             pyhd8ed1ab_0    conda-forge
pyparsing                 2.4.7              pyh9f0ad1d_0    conda-forge
pyrsistent                0.17.3           py38h497a2fe_2    conda-forge
pysocks                   1.7.1            py38h578d9bd_3    conda-forge
python                    3.8.10               h12debd9_8
python-dateutil           2.8.1                      py_0    conda-forge
python-louvain            0.15                     pypi_0    pypi
python_abi                3.8                      1_cp38    conda-forge
pytorch                   1.8.1           py3.8_cuda10.1_cudnn7.6.3_0    pytorch
pytz                      2021.1             pyhd8ed1ab_0    conda-forge
pyzmq                     22.1.0           py38h2035c66_0    conda-forge
rdflib                    5.0.0                    pypi_0    pypi
rdkit                     2021.03.3        py38hf8acc3d_0    conda-forge
readline                  8.1                  h27cfd23_0
reportlab                 3.5.67           py38hadf75a6_0    conda-forge
requests                  2.25.1             pyhd3deb0d_0    conda-forge
scikit-learn              0.24.2                   pypi_0    pypi
scipy                     1.6.3                    pypi_0    pypi
send2trash                1.5.0                      py_0    conda-forge
setuptools                52.0.0           py38h06a4308_0
shap                      0.39.0                   pypi_0    pypi
six                       1.15.0           py38h06a4308_0
slicer                    0.0.7                    pypi_0    pypi
sniffio                   1.2.0            py38h578d9bd_1    conda-forge
sqlalchemy                1.4.18           py38h497a2fe_0    conda-forge
sqlite                    3.35.4               hdfb4753_0
tabulate                  0.8.9                    pypi_0    pypi
terminado                 0.10.1           py38h578d9bd_0    conda-forge
testpath                  0.5.0              pyhd8ed1ab_0    conda-forge
threadpoolctl             2.1.0                    pypi_0    pypi
tk                        8.6.10               hbc83047_0
torch-cluster             1.5.9                    pypi_0    pypi
torch-geometric           1.7.0                    pypi_0    pypi
torch-scatter             2.0.7                    pypi_0    pypi
torch-sparse              0.6.9                    pypi_0    pypi
torch-spline-conv         1.2.1                    pypi_0    pypi
torchaudio                0.8.1                      py38    pytorch
torchvision               0.9.1                py38_cu101    pytorch
tornado                   6.1              py38h497a2fe_1    conda-forge
tqdm                      4.61.0                   pypi_0    pypi
traitlets                 5.0.5                      py_0    conda-forge
typed-argument-parser     1.5.4                    pypi_0    pypi
typing-inspect            0.7.1                    pypi_0    pypi
typing_extensions         3.7.4.3            pyha847dfd_0
tzdata                    2020f                h52ac0ba_0
urllib3                   1.26.5             pyhd8ed1ab_0    conda-forge
wcwidth                   0.2.5              pyh9f0ad1d_2    conda-forge
webencodings              0.5.1                      py_1    conda-forge
websocket-client          0.57.0           py38h578d9bd_4    conda-forge
wheel                     0.36.2             pyhd3eb1b0_0
xorg-kbproto              1.0.7             h7f98852_1002    conda-forge
xorg-libice               1.0.10               h7f98852_0    conda-forge
xorg-libsm                1.2.3             hd9c2040_1000    conda-forge
xorg-libx11               1.7.2                h7f98852_0    conda-forge
xorg-libxau               1.0.9                h7f98852_0    conda-forge
xorg-libxdmcp             1.1.3                h7f98852_0    conda-forge
xorg-libxext              1.3.4                h7f98852_1    conda-forge
xorg-libxrender           0.9.10            h7f98852_1003    conda-forge
xorg-renderproto          0.11.1            h7f98852_1002    conda-forge
xorg-xextproto            7.3.0             h7f98852_1002    conda-forge
xorg-xproto               7.0.31            h7f98852_1007    conda-forge
xz                        5.2.5                h7b6447c_0
zeromq                    4.3.4                h9c3ff4c_0    conda-forge
zipp                      3.4.1              pyhd8ed1ab_0    conda-forge
zlib                      1.2.11               h7b6447c_3
zstd                      1.4.9                haebb681_0

I have installed the latest version of DIG from source.

mengliu1998 added the xgraph (Interpretability of Graph Neural Networks) label on Jun 13, 2021
Oceanusity (Collaborator) commented Jun 13, 2021

I have disabled the pin_memory flag in the DataLoader. I think this will solve the problem; feel free to report any further problems.
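The traceback shows why this works. The DataLoader inside DIG's marginal_contribution (dig/xgraph/method/shapley.py) had pin_memory enabled, so each (exclude_data, include_data) batch was passed to Data.pin_memory(). Pinning is only defined for CPU tensors, and the tutorial had already moved the graph to the GPU with data.to(device), so the batches held CUDA tensors and the call failed. The sketch below illustrates this with plain PyTorch tensors instead of DIG's PyG Data batches. It is not DIG's code: make_loader is a hypothetical helper that enables pinning only for CPU data on a machine with a GPU, because pinned memory only speeds up host-to-GPU copies.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(values: torch.Tensor, batch_size: int = 4) -> DataLoader:
    """Hypothetical helper (not part of DIG): only pin CPU-resident data.

    With pin_memory=True the loader calls .pin_memory() on every batch, which
    raises "cannot pin 'torch.cuda.LongTensor'" for tensors already on the GPU.
    Pinning is also pointless when no GPU is available to copy to.
    """
    use_pinning = values.device.type == "cpu" and torch.cuda.is_available()
    return DataLoader(TensorDataset(values), batch_size=batch_size,
                      shuffle=False, pin_memory=use_pinning)


# Mirror the tutorial: put the data on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
values = torch.arange(10, device=device)
loader = make_loader(values)

# TensorDataset yields 1-tuples, so each batch is a sequence holding one tensor.
batches = [batch[0] for batch in loader]
print(torch.cat(batches).tolist())
```

On a GPU machine the CUDA-resident values are loaded with pinning switched off, which matches the effect of disabling the flag in DIG's loader.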

gui-li (Author) commented Jun 13, 2021

@Oceanusity Thanks again for your work. The library now works fine with the provided tutorial.

gui-li closed this as completed on Jun 13, 2021