Skip to content

Commit

Permalink
[Dataset] Chameleon (dmlc#5477)
Browse files Browse the repository at this point in the history
* update

* update

* update

* lint

* update

* CI

* lint

* update doc

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-36-188.ap-northeast-1.compute.internal>
  • Loading branch information
2 people authored and DominikaJedynak committed Mar 12, 2024
1 parent 11a6ea4 commit 6880a6d
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/source/api/python/dgl.data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ Datasets for node classification/regression tasks
YelpDataset
PATTERNDataset
CLUSTERDataset
ChameleonDataset

Edge Prediction Datasets
---------------------------------------
Expand Down
1 change: 1 addition & 0 deletions python/dgl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from .utils import *
from .cluster import CLUSTERDataset
from .pattern import PATTERNDataset
from .wiki_network import ChameleonDataset
from .wikics import WikiCSDataset
from .yelp import YelpDataset
from .zinc import ZINCDataset
Expand Down
18 changes: 15 additions & 3 deletions python/dgl/data/dgl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import abc
import hashlib
import os
import sys
import traceback

from ..utils import retry_method_with_fix
Expand Down Expand Up @@ -221,6 +220,15 @@ def _get_hash(self):
hash_func.update(str(self._hash_key).encode("utf-8"))
return hash_func.hexdigest()[:8]

def _get_hash_url_suffix(self):
"""Get the suffix based on the hash value of the url."""
if self._url is None:
return ""
else:
hash_func = hashlib.sha1()
hash_func.update(str(self._url).encode("utf-8"))
return "_" + hash_func.hexdigest()[:8]

@property
def url(self):
r"""Get url to download the raw dataset."""
Expand All @@ -241,7 +249,9 @@ def raw_path(self):
r"""Directory contains the input data files.
By default raw_path = os.path.join(self.raw_dir, self.name)
"""
return os.path.join(self.raw_dir, self.name)
return os.path.join(
self.raw_dir, self.name + self._get_hash_url_suffix()
)

@property
def save_dir(self):
Expand All @@ -251,7 +261,9 @@ def save_dir(self):
@property
def save_path(self):
r"""Path to save the processed dataset."""
return os.path.join(self._save_dir, self.name)
return os.path.join(
self.save_dir, self.name + self._get_hash_url_suffix()
)

@property
def verbose(self):
Expand Down
11 changes: 2 additions & 9 deletions python/dgl/data/qm7b.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
"""QM7b dataset for graph property prediction (regression)."""
import os

import numpy as np
from scipy import io

from .. import backend as F
from ..convert import graph as dgl_graph

from .dgl_dataset import DGLDataset
from .utils import (
check_sha1,
deprecate_property,
download,
load_graphs,
save_graphs,
)
from .utils import check_sha1, download, load_graphs, save_graphs


class QM7bDataset(DGLDataset):
Expand Down Expand Up @@ -93,7 +86,7 @@ def __init__(
)

def process(self):
mat_path = self.raw_path + ".mat"
mat_path = os.path.join(self.raw_dir, self.name + ".mat")
self.graphs, self.label = self._load_graph(mat_path)

def _load_graph(self, filename):
Expand Down
184 changes: 184 additions & 0 deletions python/dgl/data/wiki_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""
Wikipedia page-page networks on the chameleon topic.
"""
import os

import numpy as np

from ..convert import graph
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url


class WikiNetworkDataset(DGLBuiltinDataset):
r"""Wikipedia page-page networks from `Multi-scale Attributed
Node Embedding <https://arxiv.org/abs/1909.13021>`__ and later modified by
`Geom-GCN: Geometric Graph Convolutional Networks
<https://arxiv.org/abs/2002.05287>`
Parameters
----------
name : str
Name of the dataset.
raw_dir : str
Raw file directory to store the processed data.
force_reload : bool
Whether to always generate the data from scratch rather than load a
cached version.
verbose : bool
Whether to print progress information.
transform : callable
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
"""

def __init__(self, name, raw_dir, force_reload, verbose, transform):
url = _get_dgl_url(f"dataset/{name}.zip")
super(WikiNetworkDataset, self).__init__(
name=name,
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)

def process(self):
"""Load and process the data."""
try:
import torch
except ImportError:
raise ModuleNotFoundError(
"This dataset requires PyTorch to be the backend."
)

# Process node features and labels.
with open(f"{self.raw_path}/out1_node_feature_label.txt", "r") as f:
data = f.read().split("\n")[1:-1]
features = [
[float(v) for v in r.split("\t")[1].split(",")] for r in data
]
features = torch.tensor(features, dtype=torch.float)
labels = [int(r.split("\t")[2]) for r in data]
self._num_classes = max(labels) + 1
labels = torch.tensor(labels, dtype=torch.long)

# Process graph structure.
with open(f"{self.raw_path}/out1_graph_edges.txt", "r") as f:
data = f.read().split("\n")[1:-1]
data = [[int(v) for v in r.split("\t")] for r in data]
dst, src = torch.tensor(data, dtype=torch.long).t().contiguous()

self._g = graph((src, dst), num_nodes=features.size(0))
self._g.ndata["feat"] = features
self._g.ndata["label"] = labels

# Process 10 train/val/test node splits.
train_masks, val_masks, test_masks = [], [], []
for i in range(10):
filepath = f"{self.raw_path}/{self.name}_split_0.6_0.2_{i}.npz"
f = np.load(filepath)
train_masks += [torch.from_numpy(f["train_mask"])]
val_masks += [torch.from_numpy(f["val_mask"])]
test_masks += [torch.from_numpy(f["test_mask"])]
self._g.ndata["train_mask"] = torch.stack(train_masks, dim=1).bool()
self._g.ndata["val_mask"] = torch.stack(val_masks, dim=1).bool()
self._g.ndata["test_mask"] = torch.stack(test_masks, dim=1).bool()

def has_cache(self):
return os.path.exists(self.raw_path)

def load(self):
self.process()

def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph."
if self._transform is None:
return self._g
else:
return self._transform(self._g)

def __len__(self):
return 1

@property
def num_classes(self):
return self._num_classes


class ChameleonDataset(WikiNetworkDataset):
r"""Wikipedia page-page network on chameleons from `Multi-scale Attributed
Node Embedding <https://arxiv.org/abs/1909.13021>`__ and later modified by
`Geom-GCN: Geometric Graph Convolutional Networks
<https://arxiv.org/abs/2002.05287>`
Nodes represent articles from the English Wikipedia, edges reflect mutual
links between them. Node features indicate the presence of particular nouns
in the articles. The nodes were classified into 5 classes in terms of their
average monthly traffic.
Statistics:
- Nodes: 2277
- Edges: 36101
- Number of Classes: 5
- 10 splits with 60/20/20 train/val/test ratio
- Train: 1092
- Val: 729
- Test: 456
Parameters
----------
raw_dir : str, optional
Raw file directory to store the processed data. Default: ~/.dgl/
force_reload : bool, optional
Whether to always generate the data from scratch rather than load a
cached version. Default: False
verbose : bool, optional
Whether to print progress information. Default: True
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access. Default: None
Attributes
----------
num_classes : int
Number of node classes
Notes
-----
The graph does not come with edges for both directions.
Examples
--------
>>> from dgl.data import ChameleonDataset
>>> dataset = ChameleonDataset()
>>> g = dataset[0]
>>> num_classes = dataset.num_classes
>>> # get node features
>>> feat = g.ndata["feat"]
>>> # get data split
>>> train_mask = g.ndata["train_mask"]
>>> val_mask = g.ndata["val_mask"]
>>> test_mask = g.ndata["test_mask"]
>>> # get labels
>>> label = g.ndata['label']
"""

def __init__(
self, raw_dir=None, force_reload=False, verbose=True, transform=None
):
super(ChameleonDataset, self).__init__(
name="chameleon",
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
23 changes: 23 additions & 0 deletions tests/python/common/data/test_wiki_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import unittest

import backend as F

import dgl


@unittest.skipIf(
F._default_context_str == "gpu",
reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(
dgl.backend.backend_name != "pytorch", reason="only supports pytorch"
)
def test_chameleon():
transform = dgl.AddSelfLoop(allow_duplicate=True)

# chameleon
g = dgl.data.ChameleonDataset(force_reload=True)[0]
assert g.num_nodes() == 2277
assert g.num_edges() == 36101
g2 = dgl.data.ChameleonDataset(force_reload=True, transform=transform)[0]
assert g2.num_edges() - g.num_edges() == g.num_nodes()

0 comments on commit 6880a6d

Please sign in to comment.