utilities to update data objects moved to utils #190

Merged · 8 commits · Sep 6, 2023
3 changes: 2 additions & 1 deletion hydragnn/preprocess/__init__.py
@@ -8,6 +8,8 @@
get_radius_graph_config,
get_radius_graph_pbc_config,
RadiusGraphPBC,
update_predicted_values,
update_atom_features,
)

from .load_data import (
@@ -20,7 +22,6 @@
)
from .serialized_dataset_loader import (
SerializedDataLoader,
update_predicted_values,
)
from .lsms_raw_dataset_loader import LSMS_RawDataLoader
from .cfg_raw_dataset_loader import CFG_RawDataLoader
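
With this hunk both helpers are re-exported from the package root, so downstream code no longer needs to reach into `serialized_dataset_loader`. A quick sanity check of the new import path (assuming `hydragnn` is installed):

```python
# Both helpers now resolve from the package namespace; they live in hydragnn/preprocess/utils.py.
from hydragnn.preprocess import update_predicted_values, update_atom_features
```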
76 changes: 7 additions & 69 deletions hydragnn/preprocess/serialized_dataset_loader.py
@@ -21,15 +21,12 @@
PointPairFeatures,
)

from .dataset_descriptors import AtomFeatures
from hydragnn.preprocess import get_radius_graph_config
from hydragnn.preprocess import update_predicted_values, update_atom_features
from hydragnn.utils.distributed import get_device
from hydragnn.utils.print_utils import print_distributed, iterate_tqdm
from hydragnn.preprocess.utils import (
get_radius_graph,
get_radius_graph_pbc,
get_radius_graph_config,
get_radius_graph_pbc_config,
)


@@ -118,7 +115,6 @@ def load_serialized_data(self, dataset_path: str):
[Data]
List of Data objects representing atom structures.
"""
dataset = []
with open(dataset_path, "rb") as f:
_ = pickle.load(f)
_ = pickle.load(f)
@@ -141,7 +137,6 @@ def load_serialized_data(self, dataset_path: str):
loop=False,
max_neighbours=self.max_neighbours,
)
compute_edge_lengths = Distance(norm=False, cat=True)

dataset[:] = [compute_edges(data) for data in dataset]

@@ -169,11 +164,11 @@ def load_serialized_data(self, dataset_path: str):
data.edge_attr = data.edge_attr / max_edge_length

# Descriptors about topology of the local environment
for data in dataset:
if self.spherical_coordinates:
data = Spherical(data)
if self.point_pair_features:
data = PointPairFeatures(data)
if self.spherical_coordinates:
    dataset[:] = [Spherical(data) for data in dataset]

if self.point_pair_features:
    dataset[:] = [PointPairFeatures(data) for data in dataset]
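# NB: the removed per-item loop only rebound the local name `data`; unless a
# transform mutates its argument in place, the list never saw the result.
# Slice assignment replaces the list contents explicitly.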

# Move data to the device, if used. # FIXME: this does not respect the choice set by use_gpu
device = get_device(verbosity_level=self.verbosity)
@@ -188,7 +183,7 @@ def load_serialized_data(self, dataset_path: str):
data,
)

self.__update_atom_features(self.input_node_features, data)
update_atom_features(self.input_node_features, data)

if "subsample_percentage" in self.variables.keys():
self.subsample_percentage = self.variables["subsample_percentage"]
@@ -198,19 +193,6 @@ def load_serialized_data(self, dataset_path: str):

return dataset

def __update_atom_features(self, atom_features: [AtomFeatures], data: Data):
"""Updates atom features of a structure. An atom is represented with x,y,z coordinates and associated features.

Parameters
----------
atom_features: [AtomFeatures]
List of features to update. Each feature is an instance of the AtomFeatures enum.
data: Data
A Data object representing a structure that has atoms.
"""
feature_indices = [i for i in atom_features]
data.x = data.x[:, feature_indices]

def __stratified_sampling(self, dataset: [Data], subsample_percentage: float):
"""Given the dataset and the percentage of data you want to extract from it, method will
apply stratified sampling where X is the dataset and Y is are the category values for each datapoint.
@@ -257,47 +239,3 @@ def __stratified_sampling(self, dataset: [Data], subsample_percentage: float):
subsample.append(dataset[index])

return subsample
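
Most of `__stratified_sampling` is collapsed in this view; per its docstring it performs stratified subsampling, for which scikit-learn's `StratifiedShuffleSplit` (already imported in `abstractrawdataset.py` below) is the natural tool. A minimal sketch of the idea, with illustrative stand-ins rather than the collapsed body:

```python
from sklearn.model_selection import StratifiedShuffleSplit

dataset = list("abcdef")                 # stand-ins for Data objects
labels = [0, 0, 1, 1, 1, 0]              # category value for each datapoint
splitter = StratifiedShuffleSplit(n_splits=1, train_size=0.5)
keep, _ = next(splitter.split(dataset, labels))
subsample = [dataset[i] for i in keep]   # label proportions are preserved
```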


def update_predicted_values(
type: list, index: list, graph_feature_dim: list, node_feature_dim: list, data: Data
):
"""Updates values of the structure we want to predict. Predicted value is represented by integer value.
Parameters
----------
type: "graph" level or "node" level
index: index/location in data.y for graph level and in data.x for node level
graph_feature_dim: list of integers to track the dimension of each graph-level feature
node_feature_dim: list of integers to track the dimension of each node-level feature
data: Data
A Data object representing a structure that has atoms.
"""
output_feature = []
data.y_loc = torch.zeros(1, len(type) + 1, dtype=torch.int64, device=data.x.device)
for item in range(len(type)):
if type[item] == "graph":
index_counter_global_y = sum(graph_feature_dim[: index[item]])
feat_ = torch.reshape(
data.y[
index_counter_global_y : index_counter_global_y
+ graph_feature_dim[index[item]]
],
(graph_feature_dim[index[item]], 1),
)
# after the global features are spanned, we need to iterate over the nodal features;
# the counter for the nodal features needs to start from the last value of the counter for the graph-level features
elif type[item] == "node":
index_counter_nodal_y = sum(node_feature_dim[: index[item]])
feat_ = torch.reshape(
data.x[
:,
index_counter_nodal_y : (
index_counter_nodal_y + node_feature_dim[index[item]]
),
],
(-1, 1),
)
else:
raise ValueError("Unknown output type", type[item])
output_feature.append(feat_)
data.y_loc[0, item + 1] = data.y_loc[0, item] + feat_.shape[0] * feat_.shape[1]
data.y = torch.cat(output_feature, 0)
61 changes: 61 additions & 0 deletions hydragnn/preprocess/utils.py
@@ -12,11 +12,14 @@
import torch
from torch_geometric.transforms import RadiusGraph
from torch_geometric.utils import remove_self_loops, degree
from torch_geometric.data import Data

import ase
import ase.neighborlist
import os

from .dataset_descriptors import AtomFeatures

## This function can be slow if the dataset is too large. Use with caution.
## Recommended: use check_if_graph_size_variable_dist instead.
def check_if_graph_size_variable(train_loader, val_loader, test_loader):
@@ -229,3 +232,61 @@ def gather_deg_mpi(dataset):
deg += torch.bincount(d, minlength=deg.numel())
deg = MPI.COMM_WORLD.allreduce(deg.numpy(), op=MPI.SUM)
return deg


def update_predicted_values(
type: list, index: list, graph_feature_dim: list, node_feature_dim: list, data: Data
):
"""Updates values of the structure we want to predict. Predicted value is represented by integer value.
Parameters
----------
type: "graph" level or "node" level
index: index/location in data.y for graph level and in data.x for node level
graph_feature_dim: list of integers to track the dimension of each graph-level feature
node_feature_dim: list of integers to track the dimension of each node-level feature
data: Data
A Data object representing a structure that has atoms.
"""
output_feature = []
data.y_loc = torch.zeros(1, len(type) + 1, dtype=torch.int64, device=data.x.device)
for item in range(len(type)):
if type[item] == "graph":
index_counter_global_y = sum(graph_feature_dim[: index[item]])
feat_ = torch.reshape(
data.y[
index_counter_global_y : index_counter_global_y
+ graph_feature_dim[index[item]]
],
(graph_feature_dim[index[item]], 1),
)
# after the global features are spanned, we need to iterate over the nodal features;
# the counter for the nodal features needs to start from the last value of the counter for the graph-level features
elif type[item] == "node":
index_counter_nodal_y = sum(node_feature_dim[: index[item]])
feat_ = torch.reshape(
data.x[
:,
index_counter_nodal_y : (
index_counter_nodal_y + node_feature_dim[index[item]]
),
],
(-1, 1),
)
else:
raise ValueError("Unknown output type", type[item])
output_feature.append(feat_)
data.y_loc[0, item + 1] = data.y_loc[0, item] + feat_.shape[0] * feat_.shape[1]
data.y = torch.cat(output_feature, 0)


def update_atom_features(atom_features: [AtomFeatures], data: Data):
"""Updates atom features of a structure. An atom is represented with x,y,z coordinates and associated features.

Parameters
----------
atom_features: [AtomFeatures]
List of features to update. Each feature is an instance of the AtomFeatures enum.
data: Data
A Data object representing a structure that has atoms.
"""
feature_indices = [i for i in atom_features]
data.x = data.x[:, feature_indices]
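
A short usage sketch of the two relocated helpers on a toy graph. Dimensions and indices are illustrative, and plain integers stand in for `AtomFeatures` members:

```python
import torch
from torch_geometric.data import Data

from hydragnn.preprocess import update_predicted_values, update_atom_features

# Toy structure: 4 atoms with 3 node features each; 2 graph-level scalars in data.y.
data = Data(x=torch.randn(4, 3), y=torch.randn(2))

# Assemble targets: graph feature 0 plus node feature 1.
update_predicted_values(
    ["graph", "node"],  # type
    [0, 1],             # index
    [1, 1],             # graph_feature_dim
    [1, 1, 1],          # node_feature_dim
    data,
)
print(data.y.shape)  # torch.Size([5, 1]): 1 graph value + 4 per-node values
print(data.y_loc)    # tensor([[0, 1, 5]]): cumulative offsets per output head

# Then select the input columns of data.x.
update_atom_features([0, 1], data)
print(data.x.shape)  # torch.Size([4, 2])
```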
6 changes: 3 additions & 3 deletions hydragnn/utils/abstractrawdataset.py
@@ -22,7 +22,7 @@
get_radius_graph_config,
get_radius_graph_pbc_config,
)
from hydragnn.preprocess.serialized_dataset_loader import update_predicted_values
from hydragnn.preprocess import update_predicted_values

from sklearn.model_selection import StratifiedShuffleSplit

@@ -374,7 +374,7 @@ def __build_edge(self):
# edge lengths already added manually if using PBC.
# if spherical coordinates or pair point is set up, then skip directly to edge_transformation
if (not self.periodic_boundary_conditions) and (
not hasattr(self, self.edge_feature_transform)
not hasattr(self, "edge_feature_transform")
):
self.dataset[:] = [compute_edge_lengths(data) for data in self.dataset]

@@ -397,7 +397,7 @@ def __build_edge(self):
data.edge_attr = data.edge_attr / max_edge_length

# Descriptors about topology of the local environment
elif hasattr(self, self.edge_feature_transform):
elif hasattr(self, "edge_feature_transform"):
self.dataset[:] = [
self.edge_feature_transform(data) for data in self.dataset
]
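
The one-line fix in this hunk is subtle: `hasattr` expects the attribute *name* as a string. The old call first evaluated `self.edge_feature_transform` (an AttributeError when unset) and otherwise passed its non-string value to `hasattr` (a TypeError). A standalone illustration:

```python
class Example:
    edge_feature_transform = staticmethod(lambda data: data)

obj = Example()
print(hasattr(obj, "edge_feature_transform"))  # True: lookup by name (string)
# hasattr(obj, obj.edge_feature_transform)     # TypeError: attribute name must be string
```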
27 changes: 25 additions & 2 deletions hydragnn/utils/pickledataset.py
@@ -7,14 +7,15 @@
from .print_utils import print_distributed, log, iterate_tqdm

from hydragnn.utils.abstractbasedataset import AbstractBaseDataset
from hydragnn.preprocess import update_predicted_values, update_atom_features

import hydragnn.utils.tracer as tr


class SimplePickleDataset(AbstractBaseDataset):
"""Simple Pickle Dataset"""

def __init__(self, basedir, label, subset=None, preload=False):
def __init__(self, basedir, label, subset=None, preload=False, var_config=None):
"""
Parameters
----------
@@ -28,6 +29,14 @@ def __init__(self, basedir, label, subset=None, preload=False):
self.label = label
self.subset = subset
self.preload = preload
self.var_config = var_config

if self.var_config is not None:
    self.input_node_features = self.var_config["input_node_features"]
    self.variables_type = self.var_config["type"]
    self.output_index = self.var_config["output_index"]
    self.graph_feature_dim = self.var_config["graph_feature_dims"]
    self.node_feature_dim = self.var_config["node_feature_dims"]

fname = os.path.join(basedir, "%s-meta.pkl" % label)
with open(fname, "rb") as f:
@@ -48,7 +57,9 @@

if self.preload:
for i in range(self.ntotal):
self.dataset.append(self.read(i))
data = self.read(i)
self.update_data_object(data)
self.dataset.append(data)

def len(self):
return len(self.subset)
@@ -75,11 +86,23 @@ def read(self, k):
dirfname = os.path.join(self.basedir, subdir, fname)
with open(dirfname, "rb") as f:
data_object = pickle.load(f)
self.update_data_object(data_object)
return data_object

def setsubset(self, subset):
self.subset = subset

def update_data_object(self, data_object):
if self.var_config is not None:
update_predicted_values(
self.variables_type,
self.output_index,
self.graph_feature_dim,
self.node_feature_dim,
data_object,
)
update_atom_features(self.input_node_features, data_object)
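
With the new `var_config` hook, target assembly and input-feature selection now happen when each sample is read. A hypothetical configuration, with key names taken from the code above and purely illustrative values and paths:

```python
var_config = {
    "type": ["graph", "node"],         # one entry per output head
    "output_index": [0, 0],            # positions within data.y / data.x
    "graph_feature_dims": [1],
    "node_feature_dims": [3],
    "input_node_features": [0, 1, 2],  # columns of data.x to keep
}

trainset = SimplePickleDataset("dataset/serialized", "trainset", var_config=var_config)
```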


class SimplePickleWriter:
"""SimplePickleWriter class to write Torch Geometric graph data"""