diff --git a/hydragnn/preprocess/__init__.py b/hydragnn/preprocess/__init__.py index ee6513ee7..d0d52836d 100644 --- a/hydragnn/preprocess/__init__.py +++ b/hydragnn/preprocess/__init__.py @@ -8,6 +8,8 @@ get_radius_graph_config, get_radius_graph_pbc_config, RadiusGraphPBC, + update_predicted_values, + update_atom_features ) from .load_data import ( @@ -20,7 +22,6 @@ ) from .serialized_dataset_loader import ( SerializedDataLoader, - update_predicted_values, ) from .lsms_raw_dataset_loader import LSMS_RawDataLoader from .cfg_raw_dataset_loader import CFG_RawDataLoader diff --git a/hydragnn/preprocess/serialized_dataset_loader.py b/hydragnn/preprocess/serialized_dataset_loader.py index 6d3028ea4..3b1e60e1c 100644 --- a/hydragnn/preprocess/serialized_dataset_loader.py +++ b/hydragnn/preprocess/serialized_dataset_loader.py @@ -21,15 +21,12 @@ PointPairFeatures, ) -from .dataset_descriptors import AtomFeatures -from hydragnn.preprocess import get_radius_graph_config +from hydragnn.preprocess import update_predicted_values, update_atom_features from hydragnn.utils.distributed import get_device from hydragnn.utils.print_utils import print_distributed, iterate_tqdm from hydragnn.preprocess.utils import ( get_radius_graph, get_radius_graph_pbc, - get_radius_graph_config, - get_radius_graph_pbc_config, ) @@ -141,7 +138,6 @@ def load_serialized_data(self, dataset_path: str): loop=False, max_neighbours=self.max_neighbours, ) - compute_edge_lengths = Distance(norm=False, cat=True) dataset[:] = [compute_edges(data) for data in dataset] @@ -188,7 +184,7 @@ def load_serialized_data(self, dataset_path: str): data, ) - self.__update_atom_features(self.input_node_features, data) + update_atom_features(self.input_node_features, data) if "subsample_percentage" in self.variables.keys(): self.subsample_percentage = self.variables["subsample_percentage"] @@ -198,19 +194,6 @@ def load_serialized_data(self, dataset_path: str): return dataset - def __update_atom_features(self, atom_features: [AtomFeatures], data: Data): - """Updates atom features of a structure. An atom is represented with x,y,z coordinates and associated features. - - Parameters - ---------- - atom_features: [AtomFeatures] - List of features to update. Each feature is instance of Enum AtomFeatures. - data: Data - A Data object representing a structure that has atoms. - """ - feature_indices = [i for i in atom_features] - data.x = data.x[:, feature_indices] - def __stratified_sampling(self, dataset: [Data], subsample_percentage: float): """Given the dataset and the percentage of data you want to extract from it, method will apply stratified sampling where X is the dataset and Y is are the category values for each datapoint. @@ -257,47 +240,3 @@ def __stratified_sampling(self, dataset: [Data], subsample_percentage: float): subsample.append(dataset[index]) return subsample - - -def update_predicted_values( - type: list, index: list, graph_feature_dim: list, node_feature_dim: list, data: Data -): - """Updates values of the structure we want to predict. Predicted value is represented by integer value. - Parameters - ---------- - type: "graph" level or "node" level - index: index/location in data.y for graph level and in data.x for node level - graph_feature_dim: list of integers to trak the dimension of each graph level feature - data: Data - A Data object representing a structure that has atoms. - """ - output_feature = [] - data.y_loc = torch.zeros(1, len(type) + 1, dtype=torch.int64, device=data.x.device) - for item in range(len(type)): - if type[item] == "graph": - index_counter_global_y = sum(graph_feature_dim[: index[item]]) - feat_ = torch.reshape( - data.y[ - index_counter_global_y : index_counter_global_y - + graph_feature_dim[index[item]] - ], - (graph_feature_dim[index[item]], 1), - ) - # after the global features are spanned, we need to iterate over the nodal features - # to do so, the counter of the nodal features need to start from the last value of counter for the graph nodel feature - elif type[item] == "node": - index_counter_nodal_y = sum(node_feature_dim[: index[item]]) - feat_ = torch.reshape( - data.x[ - :, - index_counter_nodal_y : ( - index_counter_nodal_y + node_feature_dim[index[item]] - ), - ], - (-1, 1), - ) - else: - raise ValueError("Unknown output type", type[item]) - output_feature.append(feat_) - data.y_loc[0, item + 1] = data.y_loc[0, item] + feat_.shape[0] * feat_.shape[1] - data.y = torch.cat(output_feature, 0) diff --git a/hydragnn/preprocess/utils.py b/hydragnn/preprocess/utils.py index 6f5a08e00..8784fcb4b 100644 --- a/hydragnn/preprocess/utils.py +++ b/hydragnn/preprocess/utils.py @@ -12,11 +12,14 @@ import torch from torch_geometric.transforms import RadiusGraph from torch_geometric.utils import remove_self_loops, degree +from torch_geometric.data import Data import ase import ase.neighborlist import os +from .dataset_descriptors import AtomFeatures + ## This function can be slow if dataset is too large. Use with caution. ## Recommend to use check_if_graph_size_variable_dist def check_if_graph_size_variable(train_loader, val_loader, test_loader): @@ -229,3 +232,61 @@ def gather_deg_mpi(dataset): deg += torch.bincount(d, minlength=deg.numel()) deg = MPI.COMM_WORLD.allreduce(deg.numpy(), op=MPI.SUM) return deg + + +def update_predicted_values( + type: list, index: list, graph_feature_dim: list, node_feature_dim: list, data: Data +): + """Updates values of the structure we want to predict. Predicted value is represented by integer value. + Parameters + ---------- + type: "graph" level or "node" level + index: index/location in data.y for graph level and in data.x for node level + graph_feature_dim: list of integers to trak the dimension of each graph level feature + data: Data + A Data object representing a structure that has atoms. + """ + output_feature = [] + data.y_loc = torch.zeros(1, len(type) + 1, dtype=torch.int64, device=data.x.device) + for item in range(len(type)): + if type[item] == "graph": + index_counter_global_y = sum(graph_feature_dim[: index[item]]) + feat_ = torch.reshape( + data.y[ + index_counter_global_y : index_counter_global_y + + graph_feature_dim[index[item]] + ], + (graph_feature_dim[index[item]], 1), + ) + # after the global features are spanned, we need to iterate over the nodal features + # to do so, the counter of the nodal features need to start from the last value of counter for the graph nodel feature + elif type[item] == "node": + index_counter_nodal_y = sum(node_feature_dim[: index[item]]) + feat_ = torch.reshape( + data.x[ + :, + index_counter_nodal_y : ( + index_counter_nodal_y + node_feature_dim[index[item]] + ), + ], + (-1, 1), + ) + else: + raise ValueError("Unknown output type", type[item]) + output_feature.append(feat_) + data.y_loc[0, item + 1] = data.y_loc[0, item] + feat_.shape[0] * feat_.shape[1] + data.y = torch.cat(output_feature, 0) + + +def update_atom_features(atom_features: [AtomFeatures], data: Data): + """Updates atom features of a structure. An atom is represented with x,y,z coordinates and associated features. + + Parameters + ---------- + atom_features: [AtomFeatures] + List of features to update. Each feature is instance of Enum AtomFeatures. + data: Data + A Data object representing a structure that has atoms. + """ + feature_indices = [i for i in atom_features] + data.x = data.x[:, feature_indices] diff --git a/hydragnn/utils/pickledataset.py b/hydragnn/utils/pickledataset.py index a1fce3dc5..7b8ae0c05 100644 --- a/hydragnn/utils/pickledataset.py +++ b/hydragnn/utils/pickledataset.py @@ -7,6 +7,7 @@ from .print_utils import print_distributed, log, iterate_tqdm from hydragnn.utils.abstractbasedataset import AbstractBaseDataset +from hydragnn.preprocess import update_predicted_values, update_atom_features import hydragnn.utils.tracer as tr @@ -14,7 +15,7 @@ class SimplePickleDataset(AbstractBaseDataset): """Simple Pickle Dataset""" - def __init__(self, basedir, label, subset=None, preload=False): + def __init__(self, basedir, label, subset=None, preload=False, var_config=None): """ Parameters ---------- @@ -28,6 +29,14 @@ def __init__(self, basedir, label, subset=None, preload=False): self.label = label self.subset = subset self.preload = preload + self.var_config = var_config + self.input_node_features = var_config["input_node_features"] + + if self.var_config is not None: + self.variables_type = self.var_config["type"] + self.output_index = self.var_config["output_index"] + self.graph_feature_dim = self.var_config["graph_feature_dims"] + self.node_feature_dim = self.var_config["node_feature_dims"] fname = os.path.join(basedir, "%s-meta.pkl" % label) with open(fname, "rb") as f: @@ -48,7 +57,9 @@ def __init__(self, basedir, label, subset=None, preload=False): if self.preload: for i in range(self.ntotal): - self.dataset.append(self.read(i)) + data = self.read(i) + self.update_data_object(data) + self.dataset.append(data) def len(self): return len(self.subset) @@ -75,11 +86,23 @@ def read(self, k): dirfname = os.path.join(self.basedir, subdir, fname) with open(dirfname, "rb") as f: data_object = pickle.load(f) + self.update_data_object(data_object) return data_object def setsubset(self, subset): self.subset = subset + def update_data_object(self, data_object): + if self.var_config is not None: + update_predicted_values( + self.variables_type, + self.output_index, + self.graph_feature_dim, + self.node_feature_dim, + data_object, + ) + update_atom_features(self.input_node_features, data_object) + class SimplePickleWriter: """SimplePickleWriter class to write Torch Geometric graph data"""