Skip to content

Commit

Permalink
restructure pre-process data utils (#192)
Browse files Browse the repository at this point in the history
* remove update_atom_features from method and use implementation of the same function from utils

* function stratified_sampling moved to preprocess/utils

* formatting fixed
  • Loading branch information
allaffa authored Sep 21, 2023
1 parent ae19713 commit defda39
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 59 deletions.
1 change: 1 addition & 0 deletions hydragnn/preprocess/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
RadiusGraphPBC,
update_predicted_values,
update_atom_features,
stratified_sampling,
)

from .load_data import (
Expand Down
44 changes: 44 additions & 0 deletions hydragnn/preprocess/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,47 @@ def update_atom_features(atom_features: [AtomFeatures], data: Data):
"""
feature_indices = [i for i in atom_features]
data.x = data.x[:, feature_indices]


def stratified_sampling(dataset: [Data], subsample_percentage: float, verbosity=0):
    """Draw a stratified subsample of ``dataset``.

    Each structure is assigned an integer category derived from the sorted
    per-atom-type frequencies (a base-100 positional encoding), and
    ``StratifiedShuffleSplit`` then selects ``subsample_percentage`` of the
    data while preserving the category distribution.

    Parameters
    ----------
    dataset: [Data]
        A list of Data objects representing a structure that has atoms.
    subsample_percentage: float
        Percentage of the dataset.
    verbosity: int
        Verbosity level forwarded to the distributed print/progress helpers.

    Returns
    ----------
    [Data]
        Subsample of the original dataset constructed using stratified sampling.
    """
    print_distributed(verbosity, "Computing the categories for the whole dataset.")
    dataset_categories = []
    for sample in iterate_tqdm(dataset, verbosity):
        # Count occurrences of each atom type (column 0 of the node features),
        # keep the non-zero counts, and encode them positionally in base 100.
        counts = torch.bincount(sample.x[:, 0].int())
        counts = sorted(counts[counts > 0].tolist())
        label = sum(count * (100 ** position) for position, count in enumerate(counts))
        dataset_categories.append(label)

    splitter = StratifiedShuffleSplit(
        n_splits=1, train_size=subsample_percentage, random_state=0
    )

    # n_splits=1, so this loop runs once and captures the selected indices.
    chosen = []
    for train_indices, _ in splitter.split(dataset, dataset_categories):
        chosen = train_indices.tolist()

    return [dataset[index] for index in chosen]
65 changes: 6 additions & 59 deletions hydragnn/utils/abstractrawdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
get_radius_graph_config,
get_radius_graph_pbc_config,
)
from hydragnn.preprocess import update_predicted_values
from hydragnn.preprocess import (
update_predicted_values,
update_atom_features,
stratified_sampling,
)

from sklearn.model_selection import StratifiedShuffleSplit

Expand Down Expand Up @@ -317,19 +321,6 @@ def __scale_features_by_num_nodes(self):
/ data_object.num_nodes
)

def __update_atom_features(self, atom_features: [AtomFeatures], data: Data):
    """Restrict ``data.x`` to the feature columns listed in ``atom_features``.

    An atom is represented with x,y,z coordinates and associated features;
    this keeps only the requested feature columns, in the given order.

    Parameters
    ----------
    atom_features: [AtomFeatures]
        List of features to keep. Each feature is instance of Enum AtomFeatures.
    data: Data
        A Data object representing a structure that has atoms.
    """
    selected_columns = list(atom_features)
    data.x = data.x[:, selected_columns]

def __build_edge(self):
"""Loads the serialized structures data from specified path, computes new edges for the structures based on the maximum number of neighbours and radius. Additionally,
atom and structure features are updated.
Expand Down Expand Up @@ -411,7 +402,7 @@ def __build_edge(self):
data,
)

self.__update_atom_features(self.input_node_features, data)
update_atom_features(self.input_node_features, data)

if "subsample_percentage" in self.variables.keys():
self.subsample_percentage = self.variables["subsample_percentage"]
Expand All @@ -425,47 +416,3 @@ def len(self):

def get(self, idx):
    """Return the sample stored at position ``idx`` of the underlying dataset."""
    sample = self.dataset[idx]
    return sample


def stratified_sampling(dataset: [Data], subsample_percentage: float, verbosity=0):
    """Return a stratified subsample containing ``subsample_percentage`` of ``dataset``.

    A category label is computed per structure by sorting the non-zero
    per-atom-type counts and combining them with base-100 place values;
    ``StratifiedShuffleSplit`` then samples so that the category
    distribution of the subsample matches the full dataset.

    Parameters
    ----------
    dataset: [Data]
        A list of Data objects representing a structure that has atoms.
    subsample_percentage: float
        Percentage of the dataset.
    verbosity: int
        Verbosity level forwarded to the distributed print/progress helpers.

    Returns
    ----------
    [Data]
        Subsample of the original dataset constructed using stratified sampling.
    """
    print_distributed(verbosity, "Computing the categories for the whole dataset.")

    categories = []
    for structure in iterate_tqdm(dataset, verbosity):
        # Atom type is stored in column 0 of the node-feature matrix.
        type_counts = torch.bincount(structure.x[:, 0].int())
        nonzero_sorted = sorted(type_counts[type_counts > 0].tolist())
        encoded = 0
        for place, count in enumerate(nonzero_sorted):
            encoded += count * (100 ** place)
        categories.append(encoded)

    shuffle_split = StratifiedShuffleSplit(
        n_splits=1, train_size=subsample_percentage, random_state=0
    )

    # With a single split, the loop body executes exactly once.
    picked = []
    for kept, _discarded in shuffle_split.split(dataset, categories):
        picked = kept.tolist()

    return [dataset[i] for i in picked]

0 comments on commit defda39

Please sign in to comment.