Skip to content

Commit

Permalink
pickle and ADIOS2 formatting fixed for OGB dataset (ORNL#191)
Browse files Browse the repository at this point in the history
* pickle formatting fixed for OGB dataset

* update on adios

---------

Co-authored-by: Choi <choij@ornl.gov>
  • Loading branch information
allaffa and jychoi-hpc authored Sep 11, 2023
1 parent df51d10 commit a0d2671
Showing 1 changed file with 147 additions and 66 deletions.
213 changes: 147 additions & 66 deletions examples/ogb/train_gap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os, json
import matplotlib.pyplot as plt
import random
import pandas
import pickle, csv

import logging
Expand All @@ -12,9 +13,11 @@
import time

import hydragnn
from hydragnn.preprocess.load_data import split_dataset
from hydragnn.utils.print_utils import print_distributed, iterate_tqdm
from hydragnn.utils.time_utils import Timer
from hydragnn.utils.pickledataset import SimplePickleDataset
from hydragnn.utils.pickledataset import SimplePickleWriter, SimplePickleDataset
from hydragnn.preprocess.utils import gather_deg
from hydragnn.utils.model import print_model
from hydragnn.utils.smiles_utils import (
get_node_attribute_name,
Expand All @@ -33,9 +36,9 @@
import torch
import torch.distributed as dist

import warnings
# import warnings

warnings.filterwarnings("error")
# warnings.filterwarnings("error")

ogb_node_types = {
"H": 0,
Expand Down Expand Up @@ -76,6 +79,71 @@ def info(*args, logtype="info", sep=" "):
getattr(logging, logtype)(sep.join(map(str, args)))


from hydragnn.utils.abstractbasedataset import AbstractBaseDataset


def smiles_to_graph(datadir, files_list):

subset = []

for filename in files_list:

df = pandas.read_csv(os.path.join(datadir, filename))
rx = list(nsplit(range(len(df)), comm_size))[rank]

for smile_id in range(len(df))[rx.start : rx.stop]:
## get atomic positions and numbers
dfrow = df.iloc[smile_id]

smilestr = dfrow[0]
ytarget = (
torch.tensor(float(dfrow[-1]))
.unsqueeze(0)
.unsqueeze(1)
.to(torch.float32)
) # HL gap

data = generate_graphdata_from_smilestr(
smilestr,
ytarget,
ogb_node_types,
var_config,
)

subset.append(data)

return subset


class OGBDataset(AbstractBaseDataset):
"""OGBDataset dataset class"""

def __init__(self, dirpath, var_config, dist=False):
super().__init__()

self.var_config = var_config
self.dist = dist
if self.dist:
assert torch.distributed.is_initialized()
self.world_size = torch.distributed.get_world_size()
self.rank = torch.distributed.get_rank()

if os.path.isdir(dirpath):
dirfiles = sorted(os.listdir(dirpath))
else:
raise ValueError("OGBDataset takes dirpath as directory")

setids_files = [x for x in dirfiles if x.endswith("csv")]

self.dataset.extend(smiles_to_graph(dirpath, setids_files))

def len(self):
return len(self.dataset)

def get(self, idx):
return self.dataset[idx]


def ogb_datasets_load(datafile, sampling=None, seed=None):
if seed is not None:
random.seed(seed)
Expand Down Expand Up @@ -121,8 +189,9 @@ def __init__(self, datafile, var_config, sampling=1.0, seed=43, norm_yflag=False
smiles_sets, values_sets = ogb_datasets_load(
datafile, sampling=sampling, seed=seed
)
ymean = var_config["ymean"]
ystd = var_config["ystd"]
if norm_yflag:
ymean = var_config["ymean"]
ystd = var_config["ystd"]

info([len(x) for x in values_sets])
self.dataset_lists = list()
Expand Down Expand Up @@ -196,7 +265,7 @@ def __getitem__(self, idx):
graph_feature_names = ["GAP"]
graph_feature_dim = [1]
dirpwd = os.path.dirname(os.path.abspath(__file__))
datafile = os.path.join(dirpwd, "dataset/pcqm4m_gap.csv")
datadir = os.path.join(dirpwd, "dataset/")
##################################################################################################################
inputfilesubstr = args.inputfilesubstr
input_filename = os.path.join(dirpwd, "ogb_" + inputfilesubstr + ".json")
Expand All @@ -216,6 +285,7 @@ def __getitem__(self, idx):
var_config["input_node_feature_names"],
var_config["input_node_feature_dims"],
) = get_node_attribute_name(ogb_node_types)
var_config["node_feature_dims"] = var_config["input_node_feature_dims"]
##################################################################################################################
# Always initialize for multi-rank training.
comm_size, rank = hydragnn.utils.setup_ddp()
Expand All @@ -230,73 +300,75 @@ def __getitem__(self, idx):
datefmt="%H:%M:%S",
)

log_name = "ogb_" + inputfilesubstr + "_eV_fullx"
log_name = "ogb_" + inputfilesubstr
hydragnn.utils.setup_log(log_name)
writer = hydragnn.utils.get_summary_writer(log_name)
hydragnn.utils.save_config(config, log_name)

modelname = "ogb_" + inputfilesubstr
if args.preonly:
norm_yflag = False # True
smiles_sets, values_sets = ogb_datasets_load(
datafile, sampling=args.sampling, seed=43

## local data
total = OGBDataset(
os.path.join(datadir),
var_config,
dist=True,
)
info([len(x) for x in values_sets])
dataset_lists = [[] for dataset in values_sets]
for idataset, (smileset, valueset) in enumerate(zip(smiles_sets, values_sets)):
if norm_yflag:
valueset = (
valueset - torch.tensor(var_config["ymean"])
) / torch.tensor(var_config["ystd"])

rx = list(nsplit(range(len(smileset)), comm_size))[rank]
info("subset range:", idataset, len(smileset), rx.start, rx.stop)
## local portion
_smileset = smileset[rx.start : rx.stop]
_valueset = valueset[rx.start : rx.stop]
info("local smileset size:", len(_smileset))

setname = ["trainset", "valset", "testset"]
if args.format == "pickle":
dirname = os.path.join(os.path.dirname(__file__), "dataset", "pickle")
if rank == 0:
if not os.path.exists(dirname):
os.makedirs(dirname)
with open("%s/%s.meta" % (dirname, setname[idataset]), "w") as f:
f.write(str(len(smileset)))

for i, (smilestr, ytarget) in iterate_tqdm(
enumerate(zip(_smileset, _valueset)), verbosity, total=len(_smileset)
):
data = generate_graphdata_from_smilestr(
smilestr,
ytarget,
ogb_node_types,
var_config,
)
dataset_lists[idataset].append(data)

## (2022/07) This is for testing to compare with Adios
## pickle write
if args.format == "pickle":
fname = "%s/ogb_gap-%s-%d.pk" % (
dirname,
setname[idataset],
rx.start + i,
)
with open(fname, "wb") as f:
pickle.dump(data, f)
## This is a local split
trainset, valset, testset = split_dataset(
dataset=total,
perc_train=0.9,
stratify_splitting=False,
)
print("Local splitting: ", len(total), len(trainset), len(valset), len(testset))

deg = gather_deg(trainset)
config["pna_deg"] = deg

setnames = ["trainset", "valset", "testset"]

## local data
if args.format == "adios":
_trainset = dataset_lists[0]
_valset = dataset_lists[1]
_testset = dataset_lists[2]
if args.format == "pickle":

## pickle
basedir = os.path.join(
os.path.dirname(__file__), "dataset", "%s.pickle" % modelname
)
attrs = dict()
attrs["pna_deg"] = deg
SimplePickleWriter(
trainset,
basedir,
"trainset",
# minmax_node_feature=total.minmax_node_feature,
# minmax_graph_feature=total.minmax_graph_feature,
use_subdir=True,
attrs=attrs,
)
SimplePickleWriter(
valset,
basedir,
"valset",
# minmax_node_feature=total.minmax_node_feature,
# minmax_graph_feature=total.minmax_graph_feature,
use_subdir=True,
)
SimplePickleWriter(
testset,
basedir,
"testset",
# minmax_node_feature=total.minmax_node_feature,
# minmax_graph_feature=total.minmax_graph_feature,
use_subdir=True,
)

if args.format == "adios":
fname = os.path.join(os.path.dirname(__file__), "dataset", "ogb_gap.bp")
adwriter = AdiosWriter(fname, comm)
adwriter.add("trainset", _trainset)
adwriter.add("valset", _valset)
adwriter.add("testset", _testset)
adwriter.add("trainset", trainset)
adwriter.add("valset", valset)
adwriter.add("testset", testset)
adwriter.save()

sys.exit(0)
Expand All @@ -320,14 +392,23 @@ def __getitem__(self, idx):
valset = OGBRawDataset(fact, "valset")
testset = OGBRawDataset(fact, "testset")
elif args.format == "pickle":
dirname = os.path.join(os.path.dirname(__file__), "dataset", "pickle")
trainset = SimplePickleDataset(dirname, "ogb_gap", "trainset")
valset = SimplePickleDataset(dirname, "ogb_gap", "valset")
testset = SimplePickleDataset(dirname, "ogb_gap", "testset")
info("Pickle load")
basedir = os.path.join(
os.path.dirname(__file__), "dataset", "%s.pickle" % modelname
)
trainset = SimplePickleDataset(
basedir=basedir, label="trainset", var_config=var_config
)
valset = SimplePickleDataset(
basedir=basedir, label="valset", var_config=var_config
)
testset = SimplePickleDataset(
basedir=basedir, label="testset", var_config=var_config
)
pna_deg = trainset.pna_deg
else:
raise NotImplementedError("No supported format: %s" % (args.format))

info("Adios load")
info(
"trainset,valset,testset size: %d %d %d"
% (len(trainset), len(valset), len(testset))
Expand Down

0 comments on commit a0d2671

Please sign in to comment.