Skip to content

Commit

Permalink
fix the tensor idx
Browse files Browse the repository at this point in the history
  • Loading branch information
paoxiaode committed Oct 8, 2023
1 parent f4a6cff commit 5c8bd3c
Showing 1 changed file with 38 additions and 5 deletions.
43 changes: 38 additions & 5 deletions python/dgl/data/lrgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,26 @@ class PeptidesStructuralDataset(DGLDataset):
ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}
edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})
>>> # accept tensor to be index, but will ignore transform parameter
>>> # get train dataset
>>> split_dict = dataset.get_idx_split()
>>> trainset = dataset[split_dict["train"]]
>>> graph, label = trainset[0]
>>> graph
Graph(num_nodes=338, num_edges=682,
ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}
edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})
>>> # get subset of dataset
>>> import torch
>>> idx = torch.tensor([0, 1, 2])
>>> dataset_subset = dataset[idx]
>>> graph, label = dataset_subset[0]
>>> graph
Graph(num_nodes=119, num_edges=244,
ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}
edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})
"""

def __init__(
Expand Down Expand Up @@ -309,6 +322,9 @@ class PeptidesFunctionalDataset(DGLDataset):
ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}
edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})
>>> # accept tensor to be index, but will ignore transform parameter
>>> # get train dataset
>>> split_dict = dataset.get_idx_split()
>>> trainset = dataset[split_dict["train"]]
>>> graph, label = trainset[0]
Expand All @@ -317,6 +333,15 @@ class PeptidesFunctionalDataset(DGLDataset):
ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}
edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})
>>> # get subset of dataset
>>> import torch
>>> idx = torch.tensor([0, 1, 2])
>>> dataset_subset = dataset[idx]
>>> graph, label = dataset_subset[0]
>>> graph
Graph(num_nodes=119, num_edges=244,
ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}
edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})
"""

def __init__(
Expand Down Expand Up @@ -491,8 +516,7 @@ class VOCSuperpixelsDataset(DGLDataset):
Default: "train".
construct_format : str, optional
Option to select the graph construction format.
Should be chosen from ["edge_wt_only_coord", "edge_wt_coord_feat",
"edge_wt_region_boundary"]
Should be chosen from the following formats:
"edge_wt_only_coord": the graphs are 8-nn graphs with the edge weights
computed based on only spatial coordinates of superpixel nodes.
"edge_wt_coord_feat": the graphs are 8-nn graphs with the edge weights
Expand Down Expand Up @@ -528,6 +552,15 @@ class VOCSuperpixelsDataset(DGLDataset):
21
>>> graph = train_dataset[0]
>>> graph
Graph(num_nodes=460, num_edges=2632,
ndata_schemes={'feat': Scheme(shape=(14,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int32)}
edata_schemes={'feat': Scheme(shape=(2,), dtype=torch.float32)})
>>> # accept tensor to be index, but will ignore transform parameter
>>> import torch
>>> idx = torch.tensor([0, 1, 2])
>>> train_dataset_subset = train_dataset[idx]
>>> train_dataset_subset[0]
Graph(num_nodes=460, num_edges=2632,
ndata_schemes={'feat': Scheme(shape=(14,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int32)}
edata_schemes={'feat': Scheme(shape=(2,), dtype=torch.float32)})
Expand Down Expand Up @@ -570,13 +603,13 @@ def __init__(
):
self.construct_format = construct_format
self.slic_compactness = slic_compactness
assert split in ["train", "val", "test"]
assert split in ["train", "val", "test"], "split not valid."
assert construct_format in [
"edge_wt_only_coord",
"edge_wt_coord_feat",
"edge_wt_region_boundary",
]
assert slic_compactness in [10, 30]
], "construct_format not valid."
assert slic_compactness in [10, 30], "slic_compactness not valid."
self.split = split
super(VOCSuperpixelsDataset, self).__init__(
name="PascalVOC-SP",
Expand Down

0 comments on commit 5c8bd3c

Please sign in to comment.