The error occurs while executing your tutorial. The code block:
```python
# --- Create data collector and explanation processor ---
from dig.xgraph.evaluation import XCollector, ExplanationProcessor
x_collector = XCollector()
index = -1
node_indices = torch.where(dataset[0].test_mask * dataset[0].y != 0)[0].tolist()
data = dataset[0]

from dig.xgraph.method.subgraphx import PlotUtils
from dig.xgraph.method.subgraphx import find_closest_node_result, k_hop_subgraph_with_default_whole_graph
plotutils = PlotUtils(dataset_name='ba_shapes')

# Visualization
max_nodes = 5
node_idx = node_indices[6]
print(f'explain graph node {node_idx}')
data.to(device)
logits = model(data.x, data.edge_index)
prediction = logits[node_idx].argmax(-1).item()

_, explanation_results, related_preds = \
    explainer(data.x, data.edge_index, node_idx=node_idx, max_nodes=max_nodes)
result = find_closest_node_result(explanation_results[prediction], max_nodes=max_nodes)

plotutils = PlotUtils(dataset_name='ba_shapes')
explainer.visualization(explanation_results, prediction, max_nodes=max_nodes,
                        plot_utils=plotutils, y=data.y)
```
The error message:
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-6-7e5867836373> in <module>
     20
     21 _, explanation_results, related_preds = \
---> 22     explainer(data.x, data.edge_index, node_idx=node_idx, max_nodes=max_nodes)
     23 result = find_closest_node_result(explanation_results[prediction], max_nodes=max_nodes)
     24

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in __call__(self, x, edge_index, **kwargs)
    671         payoff_func = self.get_reward_func(value_func, node_idx=self.mcts_state_map.node_idx)
    672         self.mcts_state_map.set_score_func(payoff_func)
--> 673         results = self.mcts_state_map.mcts(verbose=False)
    674
    675         tree_node_x = find_closest_node_result(results, max_nodes=max_nodes)

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in mcts(self, verbose)
    465             print(f"The nodes in graph is {self.graph.number_of_nodes()}")
    466         for rollout_idx in range(self.n_rollout):
--> 467             self.mcts_rollout(self.root)
    468             if verbose:
    469                 print(f"At the {rollout_idx} rollout, {len(self.state_map)} states that have been explored.")

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in mcts_rollout(self, tree_node)
    450                 tree_node.children.append(new_node)
    451
--> 452             scores = compute_scores(self.score_func, tree_node.children)
    453             for child, score in zip(tree_node.children, scores):
    454                 child.P = score

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in compute_scores(score_func, children)
    163     for child in children:
    164         if child.P == 0:
--> 165             score = score_func(child.coalition, child.data)
    166         else:
    167             score = child.P

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/shapley.py in mc_l_shapley(coalition, data, local_raduis, value_func, subgraph_building_method, sample_num)
    216     include_mask = np.stack(set_include_masks, axis=0)
    217     marginal_contributions = \
--> 218         marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func)
    219
    220     mc_l_shapley_value = (marginal_contributions).mean().item()

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/shapley.py in marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func)
     73     marginal_contribution_list = []
     74
---> 75     for exclude_data, include_data in dataloader:
     76         exclude_values = value_func(exclude_data)
     77         include_values = value_func(include_data)

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/dataloader.py in __next__(self)
    515             if self._sampler_iter is None:
    516                 self._reset()
--> 517             data = self._next_data()
    518             self._num_yielded += 1
    519             if self._dataset_kind == _DatasetKind.Iterable and \

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    557         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    558         if self._pin_memory:
--> 559             data = _utils.pin_memory.pin_memory(data)
    560         return data
    561

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py in pin_memory(data)
     53         return type(data)(*(pin_memory(sample) for sample in data))
     54     elif isinstance(data, container_abcs.Sequence):
---> 55         return [pin_memory(sample) for sample in data]
     56     elif hasattr(data, "pin_memory"):
     57         return data.pin_memory()

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py in <listcomp>(.0)
     53         return type(data)(*(pin_memory(sample) for sample in data))
     54     elif isinstance(data, container_abcs.Sequence):
---> 55         return [pin_memory(sample) for sample in data]
     56     elif hasattr(data, "pin_memory"):
     57         return data.pin_memory()

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py in pin_memory(data)
     55         return [pin_memory(sample) for sample in data]
     56     elif hasattr(data, "pin_memory"):
---> 57         return data.pin_memory()
     58     else:
     59         return data

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/data/data.py in pin_memory(self, *keys)
    363         If :obj:`*keys` is not given, the conversion is applied to all present
    364         attributes."""
--> 365         return self.apply(lambda x: x.pin_memory(), *keys)
    366
    367     def debug(self):

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/data/data.py in apply(self, func, *keys)
    324         """
    325         for key, item in self(*keys):
--> 326             self[key] = self.__apply__(item, func)
    327         return self
    328

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/data/data.py in __apply__(self, item, func)
    303     def __apply__(self, item, func):
    304         if torch.is_tensor(item):
--> 305             return func(item)
    306         elif isinstance(item, SparseTensor):
    307             # Not all apply methods are supported for `SparseTensor`, e.g.,

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/data/data.py in <lambda>(x)
    363         If :obj:`*keys` is not given, the conversion is applied to all present
    364         attributes."""
--> 365         return self.apply(lambda x: x.pin_memory(), *keys)
    366
    367     def debug(self):

RuntimeError: cannot pin 'torch.cuda.LongTensor' only dense CPU tensors can be pinned
```
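The last frame pins down the cause: `Data.pin_memory()` is called on tensors that already live on the GPU, and PyTorch can only pin dense CPU tensors. A minimal, self-contained illustration of that restriction (independent of DIG, assuming a CUDA-enabled PyTorch build) is:

```python
import torch

# Pinning page-locks host (CPU) memory for faster host-to-device copies,
# so it is only defined for dense CPU tensors. Calling it on a CUDA tensor
# raises the same RuntimeError shown in the traceback above.
t = torch.arange(4, device='cuda')  # requires a CUDA-enabled build
t.pin_memory()                      # RuntimeError: cannot pin 'torch.cuda.LongTensor' ...
```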
My installed packages:
```
# packages in environment at /home/*/anaconda3/envs/dig:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                      main
_openmp_mutex             4.5                      1_gnu
anyio                     3.1.0                    py38h578d9bd_0    conda-forge
argon2-cffi               20.1.0                   py38h497a2fe_2    conda-forge
ase                       3.21.1                   pypi_0    pypi
async_generator           1.10                     py_0    conda-forge
attrs                     21.2.0                   pyhd8ed1ab_0    conda-forge
babel                     2.9.1                    pyh44b312d_0    conda-forge
backcall                  0.2.0                    pyh9f0ad1d_0    conda-forge
backports                 1.0                      py_2    conda-forge
backports.functools_lru_cache 1.6.4                pyhd8ed1ab_0    conda-forge
blas                      1.0                      mkl
bleach                    3.3.0                    pyh44b312d_0    conda-forge
boost                     1.74.0                   py38hc10631b_3    conda-forge
boost-cpp                 1.74.0                   hc6e9bd1_3    conda-forge
brotlipy                  0.7.0                    py38h497a2fe_1001    conda-forge
bzip2                     1.0.8                    h7b6447c_0
ca-certificates           2021.5.30                ha878542_0    conda-forge
cairo                     1.16.0                   h6cf1ce9_1008    conda-forge
captum                    0.2.0                    pypi_0    pypi
certifi                   2021.5.30                py38h578d9bd_0    conda-forge
cffi                      1.14.5                   py38ha65f79e_0    conda-forge
chardet                   4.0.0                    py38h578d9bd_1    conda-forge
cilog                     1.2.0                    pypi_0    pypi
cloudpickle               1.6.0                    pypi_0    pypi
cryptography              3.4.7                    py38ha5dfef3_0    conda-forge
cudatoolkit               10.1.243                 h6bb024c_0
cycler                    0.10.0                   py_2    conda-forge
decorator                 4.4.2                    pypi_0    pypi
defusedxml                0.7.1                    pyhd8ed1ab_0    conda-forge
dive-into-graphs          0.0.4                    pypi_0    pypi
entrypoints               0.3                      pyhd8ed1ab_1003    conda-forge
et-xmlfile                1.1.0                    pypi_0    pypi
ffmpeg                    4.3                      hf484d3e_0    pytorch
fontconfig                2.13.1                   hba837de_1005    conda-forge
freetype                  2.10.4                   h5ab3b9f_0
gettext                   0.19.8.1                 h0b5b191_1005    conda-forge
gmp                       6.2.1                    h2531618_2
gnutls                    3.6.15                   he1e5248_0
googledrivedownloader     0.4                      pypi_0    pypi
greenlet                  1.1.0                    py38h709712a_0    conda-forge
h5py                      3.2.1                    pypi_0    pypi
icu                       68.1                     h58526e2_0    conda-forge
idna                      2.10                     pyh9f0ad1d_0    conda-forge
importlib-metadata        4.5.0                    py38h578d9bd_0    conda-forge
intel-openmp              2021.2.0                 h06a4308_610
ipykernel                 5.5.5                    py38hd0cf306_0    conda-forge
ipython                   7.24.1                   py38hd0cf306_0    conda-forge
ipython_genutils          0.2.0                    py_1    conda-forge
isodate                   0.6.0                    pypi_0    pypi
jedi                      0.18.0                   py38h578d9bd_2    conda-forge
jinja2                    3.0.1                    pyhd8ed1ab_0    conda-forge
joblib                    1.0.1                    pypi_0    pypi
jpeg                      9b                       h024ee3a_2
json5                     0.9.5                    pyh9f0ad1d_0    conda-forge
jsonschema                3.2.0                    pyhd8ed1ab_3    conda-forge
jupyter_client            6.1.12                   pyhd8ed1ab_0    conda-forge
jupyter_core              4.7.1                    py38h578d9bd_0    conda-forge
jupyter_server            1.8.0                    pyhd8ed1ab_0    conda-forge
jupyterlab                3.0.16                   pyhd8ed1ab_0    conda-forge
jupyterlab_pygments       0.1.2                    pyh9f0ad1d_0    conda-forge
jupyterlab_server         2.6.0                    pyhd8ed1ab_0    conda-forge
kiwisolver                1.3.1                    py38h1fd1430_1    conda-forge
lame                      3.100                    h7b6447c_0
lcms2                     2.12                     h3be6417_0
ld_impl_linux-64          2.35.1                   h7274673_9
libffi                    3.3                      he6710b0_2
libgcc-ng                 9.3.0                    h5101ec6_17
libglib                   2.68.3                   h3e27bee_0    conda-forge
libgomp                   9.3.0                    h5101ec6_17
libiconv                  1.16                     h516909a_0    conda-forge
libidn2                   2.3.1                    h27cfd23_0
libpng                    1.6.37                   hbc83047_0
libsodium                 1.0.18                   h36c2ea0_1    conda-forge
libstdcxx-ng              9.3.0                    hd4cf53a_17
libtasn1                  4.16.0                   h27cfd23_0
libtiff                   4.2.0                    h85742a9_0
libunistring              0.9.10                   h27cfd23_0
libuuid                   2.32.1                   h7f98852_1000    conda-forge
libuv                     1.40.0                   h7b6447c_0
libwebp-base              1.2.0                    h27cfd23_0
libxcb                    1.13                     h7f98852_1003    conda-forge
libxml2                   2.9.12                   h72842e0_0    conda-forge
llvmlite                  0.36.0                   pypi_0    pypi
lz4-c                     1.9.3                    h2531618_0
markupsafe                2.0.1                    py38h497a2fe_0    conda-forge
matplotlib-base           3.4.2                    py38hcc49a3a_0    conda-forge
matplotlib-inline         0.1.2                    pyhd8ed1ab_2    conda-forge
mistune                   0.8.4                    py38h497a2fe_1003    conda-forge
mkl                       2021.2.0                 h06a4308_296
mkl-service               2.3.0                    py38h27cfd23_1
mkl_fft                   1.3.0                    py38h42c9631_2
mkl_random                1.2.1                    py38ha9443f7_2
mypy-extensions           0.4.3                    pypi_0    pypi
nbclassic                 0.3.1                    pyhd8ed1ab_1    conda-forge
nbclient                  0.5.3                    pyhd8ed1ab_0    conda-forge
nbconvert                 6.0.7                    py38h578d9bd_3    conda-forge
nbformat                  5.1.3                    pyhd8ed1ab_0    conda-forge
ncurses                   6.2                      he6710b0_1
nest-asyncio              1.5.1                    pyhd8ed1ab_0    conda-forge
nettle                    3.7.3                    hbbd107a_1
networkx                  2.5.1                    pypi_0    pypi
ninja                     1.10.2                   hff7bd54_1
notebook                  6.4.0                    pyha770c72_0    conda-forge
numba                     0.53.1                   pypi_0    pypi
numpy                     1.20.2                   py38h2d18471_0
numpy-base                1.20.2                   py38hfae3a4d_0
olefile                   0.46                     py_0
openh264                  2.1.0                    hd408876_0
openpyxl                  3.0.7                    pypi_0    pypi
openssl                   1.1.1k                   h7f98852_0    conda-forge
packaging                 20.9                     pyh44b312d_0    conda-forge
pandas                    1.2.4                    pypi_0    pypi
pandoc                    2.14.0.1                 h7f98852_0    conda-forge
pandocfilters             1.4.2                    py_1    conda-forge
parso                     0.8.2                    pyhd8ed1ab_0    conda-forge
pcre                      8.44                     he1b5a44_0    conda-forge
pexpect                   4.8.0                    pyh9f0ad1d_2    conda-forge
pickleshare               0.7.5                    py_1003    conda-forge
pillow                    8.2.0                    py38he98fc37_0
pip                       21.1.2                   py38h06a4308_0
pixman                    0.40.0                   h36c2ea0_0    conda-forge
prometheus_client         0.11.0                   pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.18                   pyha770c72_0    conda-forge
pthread-stubs             0.4                      h36c2ea0_1001    conda-forge
ptyprocess                0.7.0                    pyhd3deb0d_0    conda-forge
pycairo                   1.20.1                   py38hf61ee4a_0    conda-forge
pycparser                 2.20                     pyh9f0ad1d_2    conda-forge
pygments                  2.9.0                    pyhd8ed1ab_0    conda-forge
pyopenssl                 20.0.1                   pyhd8ed1ab_0    conda-forge
pyparsing                 2.4.7                    pyh9f0ad1d_0    conda-forge
pyrsistent                0.17.3                   py38h497a2fe_2    conda-forge
pysocks                   1.7.1                    py38h578d9bd_3    conda-forge
python                    3.8.10                   h12debd9_8
python-dateutil           2.8.1                    py_0    conda-forge
python-louvain            0.15                     pypi_0    pypi
python_abi                3.8                      1_cp38    conda-forge
pytorch                   1.8.1                    py3.8_cuda10.1_cudnn7.6.3_0    pytorch
pytz                      2021.1                   pyhd8ed1ab_0    conda-forge
pyzmq                     22.1.0                   py38h2035c66_0    conda-forge
rdflib                    5.0.0                    pypi_0    pypi
rdkit                     2021.03.3                py38hf8acc3d_0    conda-forge
readline                  8.1                      h27cfd23_0
reportlab                 3.5.67                   py38hadf75a6_0    conda-forge
requests                  2.25.1                   pyhd3deb0d_0    conda-forge
scikit-learn              0.24.2                   pypi_0    pypi
scipy                     1.6.3                    pypi_0    pypi
send2trash                1.5.0                    py_0    conda-forge
setuptools                52.0.0                   py38h06a4308_0
shap                      0.39.0                   pypi_0    pypi
six                       1.15.0                   py38h06a4308_0
slicer                    0.0.7                    pypi_0    pypi
sniffio                   1.2.0                    py38h578d9bd_1    conda-forge
sqlalchemy                1.4.18                   py38h497a2fe_0    conda-forge
sqlite                    3.35.4                   hdfb4753_0
tabulate                  0.8.9                    pypi_0    pypi
terminado                 0.10.1                   py38h578d9bd_0    conda-forge
testpath                  0.5.0                    pyhd8ed1ab_0    conda-forge
threadpoolctl             2.1.0                    pypi_0    pypi
tk                        8.6.10                   hbc83047_0
torch-cluster             1.5.9                    pypi_0    pypi
torch-geometric           1.7.0                    pypi_0    pypi
torch-scatter             2.0.7                    pypi_0    pypi
torch-sparse              0.6.9                    pypi_0    pypi
torch-spline-conv         1.2.1                    pypi_0    pypi
torchaudio                0.8.1                    py38    pytorch
torchvision               0.9.1                    py38_cu101    pytorch
tornado                   6.1                      py38h497a2fe_1    conda-forge
tqdm                      4.61.0                   pypi_0    pypi
traitlets                 5.0.5                    py_0    conda-forge
typed-argument-parser     1.5.4                    pypi_0    pypi
typing-inspect            0.7.1                    pypi_0    pypi
typing_extensions         3.7.4.3                  pyha847dfd_0
tzdata                    2020f                    h52ac0ba_0
urllib3                   1.26.5                   pyhd8ed1ab_0    conda-forge
wcwidth                   0.2.5                    pyh9f0ad1d_2    conda-forge
webencodings              0.5.1                    py_1    conda-forge
websocket-client          0.57.0                   py38h578d9bd_4    conda-forge
wheel                     0.36.2                   pyhd3eb1b0_0
xorg-kbproto              1.0.7                    h7f98852_1002    conda-forge
xorg-libice               1.0.10                   h7f98852_0    conda-forge
xorg-libsm                1.2.3                    hd9c2040_1000    conda-forge
xorg-libx11               1.7.2                    h7f98852_0    conda-forge
xorg-libxau               1.0.9                    h7f98852_0    conda-forge
xorg-libxdmcp             1.1.3                    h7f98852_0    conda-forge
xorg-libxext              1.3.4                    h7f98852_1    conda-forge
xorg-libxrender           0.9.10                   h7f98852_1003    conda-forge
xorg-renderproto          0.11.1                   h7f98852_1002    conda-forge
xorg-xextproto            7.3.0                    h7f98852_1002    conda-forge
xorg-xproto               7.0.31                   h7f98852_1007    conda-forge
xz                        5.2.5                    h7b6447c_0
zeromq                    4.3.4                    h9c3ff4c_0    conda-forge
zipp                      3.4.1                    pyhd8ed1ab_0    conda-forge
zlib                      1.2.11                   h7b6447c_3
zstd                      1.4.9                    haebb681_0
```
I have installed the latest version of DIG from source.
I have disabled the `pin_memory` flag in the DataLoader. I think this will solve the problem; feel free to report any further issues.
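For readers who want to apply the change locally before upgrading, here is a minimal sketch of what the fix amounts to. It assumes the loader is the one the traceback shows inside `dig/xgraph/method/shapley.py::marginal_contribution`, built with torch_geometric's `DataLoader`; the helper name below is hypothetical and only illustrates the one-flag change:

```python
from torch_geometric.data import DataLoader  # PyG 1.7-style batching of Data objects

# Hypothetical helper, not DIG's actual code: the point is only that the
# DataLoader over the masked subgraphs must be created with pin_memory=False,
# because those Data objects already live on the GPU and CUDA tensors cannot be pinned.
def build_marginal_loader(marginal_subgraph_dataset, batch_size=256):
    return DataLoader(marginal_subgraph_dataset,
                      batch_size=batch_size,
                      shuffle=False,
                      pin_memory=False)  # the change described above
```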
@Oceanusity Thanks again for your work. The library now works fine with the provided tutorial.