Commit
utilities to update data objects moved to utils
allaffa committed Aug 23, 2023
1 parent 36682cc commit 10c0b15
Showing 4 changed files with 90 additions and 66 deletions.
3 changes: 2 additions & 1 deletion hydragnn/preprocess/__init__.py
@@ -8,6 +8,8 @@
    get_radius_graph_config,
    get_radius_graph_pbc_config,
    RadiusGraphPBC,
    update_predicted_values,
    update_atom_features
)

from .load_data import (
@@ -20,7 +22,6 @@
)
from .serialized_dataset_loader import (
    SerializedDataLoader,
    update_predicted_values,
)
from .lsms_raw_dataset_loader import LSMS_RawDataLoader
from .cfg_raw_dataset_loader import CFG_RawDataLoader
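With this change, both helpers are re-exported from the package root. A minimal sketch of the resulting import surface (assuming the package is installed):

```python
# update_predicted_values and update_atom_features now live in
# hydragnn/preprocess/utils.py but remain importable from the package namespace.
from hydragnn.preprocess import update_predicted_values, update_atom_features
```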
65 changes: 2 additions & 63 deletions hydragnn/preprocess/serialized_dataset_loader.py
@@ -21,15 +21,12 @@
    PointPairFeatures,
)

from .dataset_descriptors import AtomFeatures
from hydragnn.preprocess import get_radius_graph_config
from hydragnn.preprocess import update_predicted_values, update_atom_features
from hydragnn.utils.distributed import get_device
from hydragnn.utils.print_utils import print_distributed, iterate_tqdm
from hydragnn.preprocess.utils import (
    get_radius_graph,
    get_radius_graph_pbc,
    get_radius_graph_config,
    get_radius_graph_pbc_config,
)


@@ -141,7 +138,6 @@ def load_serialized_data(self, dataset_path: str):
            loop=False,
            max_neighbours=self.max_neighbours,
        )
        compute_edge_lengths = Distance(norm=False, cat=True)

        dataset[:] = [compute_edges(data) for data in dataset]

Expand Down Expand Up @@ -188,7 +184,7 @@ def load_serialized_data(self, dataset_path: str):
                data,
            )

            self.__update_atom_features(self.input_node_features, data)
            update_atom_features(self.input_node_features, data)

if "subsample_percentage" in self.variables.keys():
self.subsample_percentage = self.variables["subsample_percentage"]
@@ -198,19 +194,6 @@ def load_serialized_data(self, dataset_path: str):

        return dataset

    def __update_atom_features(self, atom_features: [AtomFeatures], data: Data):
        """Updates the atom features of a structure. An atom is represented by x,y,z coordinates and associated features.
        Parameters
        ----------
        atom_features: [AtomFeatures]
            List of features to keep. Each feature is an instance of the AtomFeatures enum.
        data: Data
            A Data object representing a structure that has atoms.
        """
        feature_indices = [i for i in atom_features]
        data.x = data.x[:, feature_indices]

    def __stratified_sampling(self, dataset: [Data], subsample_percentage: float):
        """Given the dataset and the percentage of data you want to extract from it, the method
        applies stratified sampling, where X is the dataset and Y are the category values for each datapoint.
@@ -257,47 +240,3 @@ def __stratified_sampling(self, dataset: [Data], subsample_percentage: float):
            subsample.append(dataset[index])

        return subsample


def update_predicted_values(
    type: list, index: list, graph_feature_dim: list, node_feature_dim: list, data: Data
):
    """Updates the values of the structure we want to predict. Each predicted value is identified by an integer index.
    Parameters
    ----------
    type: "graph" level or "node" level
    index: index/location in data.y for graph level and in data.x for node level
    graph_feature_dim: list of integers to track the dimension of each graph level feature
    node_feature_dim: list of integers to track the dimension of each node level feature
    data: Data
        A Data object representing a structure that has atoms.
    """
    output_feature = []
    data.y_loc = torch.zeros(1, len(type) + 1, dtype=torch.int64, device=data.x.device)
    for item in range(len(type)):
        if type[item] == "graph":
            index_counter_global_y = sum(graph_feature_dim[: index[item]])
            feat_ = torch.reshape(
                data.y[
                    index_counter_global_y : index_counter_global_y
                    + graph_feature_dim[index[item]]
                ],
                (graph_feature_dim[index[item]], 1),
            )
        # after the global features are spanned, we need to iterate over the nodal features;
        # to do so, the counter of the nodal features needs to start from the last value of the counter for the graph level features
        elif type[item] == "node":
            index_counter_nodal_y = sum(node_feature_dim[: index[item]])
            feat_ = torch.reshape(
                data.x[
                    :,
                    index_counter_nodal_y : (
                        index_counter_nodal_y + node_feature_dim[index[item]]
                    ),
                ],
                (-1, 1),
            )
        else:
            raise ValueError("Unknown output type", type[item])
        output_feature.append(feat_)
        data.y_loc[0, item + 1] = data.y_loc[0, item] + feat_.shape[0] * feat_.shape[1]
    data.y = torch.cat(output_feature, 0)
61 changes: 61 additions & 0 deletions hydragnn/preprocess/utils.py
@@ -12,11 +12,14 @@
import torch
from torch_geometric.transforms import RadiusGraph
from torch_geometric.utils import remove_self_loops, degree
from torch_geometric.data import Data

import ase
import ase.neighborlist
import os

from .dataset_descriptors import AtomFeatures

## This function can be slow if dataset is too large. Use with caution.
## Recommend to use check_if_graph_size_variable_dist
def check_if_graph_size_variable(train_loader, val_loader, test_loader):
@@ -229,3 +232,61 @@ def gather_deg_mpi(dataset):
        deg += torch.bincount(d, minlength=deg.numel())
    deg = MPI.COMM_WORLD.allreduce(deg.numpy(), op=MPI.SUM)
    return deg


def update_predicted_values(
    type: list, index: list, graph_feature_dim: list, node_feature_dim: list, data: Data
):
    """Updates the values of the structure we want to predict. Each predicted value is identified by an integer index.
    Parameters
    ----------
    type: "graph" level or "node" level
    index: index/location in data.y for graph level and in data.x for node level
    graph_feature_dim: list of integers to track the dimension of each graph level feature
    node_feature_dim: list of integers to track the dimension of each node level feature
    data: Data
        A Data object representing a structure that has atoms.
    """
    output_feature = []
    data.y_loc = torch.zeros(1, len(type) + 1, dtype=torch.int64, device=data.x.device)
    for item in range(len(type)):
        if type[item] == "graph":
            index_counter_global_y = sum(graph_feature_dim[: index[item]])
            feat_ = torch.reshape(
                data.y[
                    index_counter_global_y : index_counter_global_y
                    + graph_feature_dim[index[item]]
                ],
                (graph_feature_dim[index[item]], 1),
            )
        # after the global features are spanned, we need to iterate over the nodal features;
        # to do so, the counter of the nodal features needs to start from the last value of the counter for the graph level features
        elif type[item] == "node":
            index_counter_nodal_y = sum(node_feature_dim[: index[item]])
            feat_ = torch.reshape(
                data.x[
                    :,
                    index_counter_nodal_y : (
                        index_counter_nodal_y + node_feature_dim[index[item]]
                    ),
                ],
                (-1, 1),
            )
        else:
            raise ValueError("Unknown output type", type[item])
        output_feature.append(feat_)
        data.y_loc[0, item + 1] = data.y_loc[0, item] + feat_.shape[0] * feat_.shape[1]
    data.y = torch.cat(output_feature, 0)
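A minimal usage sketch of update_predicted_values on a toy Data object; the shapes, feature dimensions, and values below are illustrative assumptions, not values from this repository:

```python
import torch
from torch_geometric.data import Data

from hydragnn.preprocess.utils import update_predicted_values

# Toy structure: 3 atoms with 2 nodal features each; data.y concatenates
# two graph-level features of dimensions 2 and 3 (5 scalars in total).
data = Data(x=torch.arange(6.0).reshape(3, 2), y=torch.arange(5.0))

# Predict graph-level feature 0 (dim 2) and nodal feature 1 (dim 1).
update_predicted_values(
    type=["graph", "node"],
    index=[0, 1],
    graph_feature_dim=[2, 3],
    node_feature_dim=[1, 1],
    data=data,
)

print(data.y_loc)    # tensor([[0, 2, 5]]): running offsets of each output head
print(data.y.shape)  # torch.Size([5, 1]): 2 graph values followed by 3 nodal values
```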


def update_atom_features(atom_features: [AtomFeatures], data: Data):
    """Updates the atom features of a structure. An atom is represented by x,y,z coordinates and associated features.
    Parameters
    ----------
    atom_features: [AtomFeatures]
        List of features to keep. Each feature is an instance of the AtomFeatures enum.
    data: Data
        A Data object representing a structure that has atoms.
    """
    feature_indices = [i for i in atom_features]
    data.x = data.x[:, feature_indices]
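A corresponding sketch for update_atom_features; plain integer column indices are used for illustration, whereas the codebase passes AtomFeatures enum members, which index columns the same way:

```python
import torch
from torch_geometric.data import Data

from hydragnn.preprocess.utils import update_atom_features

# Toy structure: 3 atoms, 4 candidate feature columns in data.x.
data = Data(x=torch.randn(3, 4))

# Keep only columns 0 and 2.
update_atom_features([0, 2], data)

print(data.x.shape)  # torch.Size([3, 2])
```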
27 changes: 25 additions & 2 deletions hydragnn/utils/pickledataset.py
@@ -7,14 +7,15 @@
from .print_utils import print_distributed, log, iterate_tqdm

from hydragnn.utils.abstractbasedataset import AbstractBaseDataset
from hydragnn.preprocess import update_predicted_values, update_atom_features

import hydragnn.utils.tracer as tr


class SimplePickleDataset(AbstractBaseDataset):
"""Simple Pickle Dataset"""

def __init__(self, basedir, label, subset=None, preload=False):
def __init__(self, basedir, label, subset=None, preload=False, var_config=None):
"""
Parameters
----------
@@ -28,6 +29,14 @@ def __init__(self, basedir, label, subset=None, preload=False):
        self.label = label
        self.subset = subset
        self.preload = preload
        self.var_config = var_config

        if self.var_config is not None:
            self.input_node_features = self.var_config["input_node_features"]
            self.variables_type = self.var_config["type"]
            self.output_index = self.var_config["output_index"]
            self.graph_feature_dim = self.var_config["graph_feature_dims"]
            self.node_feature_dim = self.var_config["node_feature_dims"]

        fname = os.path.join(basedir, "%s-meta.pkl" % label)
        with open(fname, "rb") as f:
@@ -48,7 +57,9 @@ def __init__(self, basedir, label, subset=None, preload=False, var_config=None):

        if self.preload:
            for i in range(self.ntotal):
                self.dataset.append(self.read(i))
                data = self.read(i)
                self.update_data_object(data)
                self.dataset.append(data)

    def len(self):
        return len(self.subset)
@@ -75,11 +86,23 @@ def read(self, k):
        dirfname = os.path.join(self.basedir, subdir, fname)
        with open(dirfname, "rb") as f:
            data_object = pickle.load(f)
        self.update_data_object(data_object)
        return data_object

    def setsubset(self, subset):
        self.subset = subset

    def update_data_object(self, data_object):
        if self.var_config is not None:
            update_predicted_values(
                self.variables_type,
                self.output_index,
                self.graph_feature_dim,
                self.node_feature_dim,
                data_object,
            )
            update_atom_features(self.input_node_features, data_object)
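A hedged sketch of how a caller might pass var_config so every sample is transformed on read; the directory, label, and dimension values are hypothetical:

```python
from hydragnn.utils.pickledataset import SimplePickleDataset

# Hypothetical configuration; the keys mirror what __init__ reads above.
var_config = {
    "input_node_features": [0],
    "type": ["graph"],
    "output_index": [0],
    "graph_feature_dims": [2],
    "node_feature_dims": [1],
}

# Each data object returned by read() has update_predicted_values and
# update_atom_features already applied (or once up front with preload=True).
trainset = SimplePickleDataset(
    basedir="dataset/serialized", label="trainset", var_config=var_config
)
```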


class SimplePickleWriter:
"""SimplePickleWriter class to write Torch Geometric graph data"""