Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sh/datasets #300

Merged
merged 11 commits into from
Jun 11, 2021
10 changes: 10 additions & 0 deletions src/schnetpack/configs/data/ani1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Data config for the ANI1 dataset; `_target_` names the class that is instantiated
# with the remaining keys as arguments.
_target_: schnetpack.datasets.ANI1

datapath: ${data_dir}/ani1.db # data_dir is specified in train.yaml
batch_size: 32
# Split sizes are given as absolute numbers of examples here (values above 1).
num_train: 10000000
num_val: 100000
num_test: null
# NOTE(review): presumably the maximum number of heavy atoms per molecule to include - confirm against schnetpack.datasets.ANI1.
num_heavy_atoms: 8
# NOTE(review): presumably toggles loading of high-energy structures - confirm against schnetpack.datasets.ANI1.
high_energies: False
num_workers: 8
4 changes: 2 additions & 2 deletions src/schnetpack/configs/data/custom.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ _target_: schnetpack.data.AtomsDataModule

datapath: null # data_dir is specified in train.yaml
batch_size: 32
num_train: 10000
num_val: 1000
num_train: 0.8
num_val: 0.1
num_test: null
num_workers: 8
9 changes: 9 additions & 0 deletions src/schnetpack/configs/data/iso17.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Data config for the ISO17 dataset; `_target_` names the class that is instantiated
# with the remaining keys as arguments.
_target_: schnetpack.datasets.ISO17

# The database file name is interpolated from `folder` below (addressed as
# `data.folder`, i.e. this config is expected to be mounted under `data`).
datapath: ${data_dir}/${data.folder}.db # data_dir is specified in train.yaml
folder: reference
batch_size: 32
# Split sizes are given as fractions of the dataset.
# NOTE(review): presumably converted to absolute counts by the data module's partition step - confirm ISO17 goes through it.
num_train: 0.9
num_val: 0.1
num_test: null
num_workers: 8
9 changes: 9 additions & 0 deletions src/schnetpack/configs/data/materials_project.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Data config for the Materials Project dataset; `_target_` names the class that is
# instantiated with the remaining keys as arguments.
_target_: schnetpack.datasets.MaterialsProject

datapath: ${data_dir}/materials_project.db # data_dir is specified in train.yaml
batch_size: 32
# Split sizes are given as absolute numbers of examples here (values above 1).
num_train: 60000
num_val: 2000
num_test: null
num_workers: 8
# NOTE(review): presumably the Materials Project API key needed to download the data; left unset by default - confirm.
apikey: null
9 changes: 9 additions & 0 deletions src/schnetpack/configs/data/omdb.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Data config for the Organic Materials Database (OMDB); `_target_` names the class
# that is instantiated with the remaining keys as arguments.
_target_: schnetpack.datasets.OrganicMaterialsDatabase

datapath: ${data_dir}/omdb.db # data_dir is specified in train.yaml
batch_size: 32
# Split sizes are given as fractions of the dataset.
# NOTE(review): presumably converted to absolute counts by the data module's partition step - confirm.
num_train: 0.8
num_val: 0.1
num_test: null
num_workers: 8
# NOTE(review): presumably the path to the raw OMDB download to convert; left unset by default - confirm.
raw_path: null
31 changes: 31 additions & 0 deletions src/schnetpack/configs/experiment/ani1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# @package _global_

# Experiment: single-property model with a SchNet representation on the ANI1 data config.
defaults:
  - override /model: single_property
  - override /model/representation: schnet
  - override /data: ani1
  - override /logger: tensorboard

# Global hyperparameters; `cutoff` is reused below via interpolation.
cutoff: 5.0
n_rbf: 20

lr: 1e-3

# Target property; also used to build the run name and by RemoveOffsets below.
property: energy
name: ani1_${property}


# Keys under `data` are merged into the selected data config (ani1).
data:
  distance_unit: Ang
  property_units:
    energy: hartree
  # Transforms are listed in the order they are declared; each entry's
  # parameters must stay nested under its own `_target_`.
  transforms:
    - _target_: schnetpack.transform.SubtractCenterOfMass
    - _target_: schnetpack.transform.RemoveOffsets
      property: ${property}
      remove_atomrefs: True
      remove_mean: True
    - _target_: schnetpack.transform.TorchNeighborList
      cutoff: ${cutoff}
    - _target_: schnetpack.transform.CastTo32

37 changes: 37 additions & 0 deletions src/schnetpack/configs/experiment/iso17.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# @package _global_

# Experiment: `pes` model (energy and forces) on the ISO17 data config.
defaults:
  - override /model: pes
  - override /data: iso17
  - override /logger: tensorboard

# Global hyperparameters; `cutoff` is reused below via interpolation.
cutoff: 5.0
n_rbf: 20
lr: 1e-4

# Run name. The ISO17 data config defines `folder` but no `molecule` key, so the
# MD17-style `md17_${data.molecule}` could not be resolved and mislabelled the run.
name: iso17_${data.folder}

# Keys under `data` are merged into the selected data config (iso17).
data:
  distance_unit: Ang
  # NOTE(review): units are keyed `energy`/`forces` while the model below reads
  # `total_energy`/`atomic_forces` - confirm the keys match the dataset's property names.
  property_units:
    energy: eV
    forces: eV/Ang
  # Each entry's parameters must stay nested under its own `_target_`.
  transforms:
    - _target_: schnetpack.transform.SubtractCenterOfMass
    - _target_: schnetpack.transform.RemoveOffsets
      property: total_energy
      remove_mean: True
    - _target_: schnetpack.transform.TorchNeighborList
      cutoff: ${cutoff}
    - _target_: schnetpack.transform.CastTo32


# Keys under `model` are merged into the selected model config (pes).
model:
  energy_property: total_energy
  forces_property: atomic_forces
  # Post-processing mirrors the input pipeline: cast back to 64 bit, then
  # re-add the mean that RemoveOffsets subtracted from `total_energy`.
  postprocess:
    - _target_: schnetpack.transform.CastTo64
    - _target_: schnetpack.transform.AddOffsets
      property: total_energy
      add_mean: True

31 changes: 31 additions & 0 deletions src/schnetpack/configs/experiment/materials_project.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# @package _global_

# Experiment: single-property model with a SchNet representation on the
# Materials Project data config.
defaults:
  - override /model: single_property
  - override /model/representation: schnet
  - override /data: materials_project
  - override /logger: tensorboard

# Global hyperparameters; `cutoff` is reused below via interpolation.
cutoff: 5.0
n_rbf: 20

lr: 1e-3

# Target property; also used to build the run name and by RemoveOffsets below.
property: formation_energy_per_atom
name: mp_${property}


# Keys under `data` are merged into the selected data config (materials_project).
data:
  distance_unit: Ang
  # NOTE(review): units are keyed `energy` while the trained property is
  # `formation_energy_per_atom` - confirm the key matches a dataset property name.
  property_units:
    energy: eV
  # Each entry's parameters must stay nested under its own `_target_`.
  transforms:
    - _target_: schnetpack.transform.SubtractCenterOfMass
    - _target_: schnetpack.transform.RemoveOffsets
      property: ${property}
      remove_atomrefs: True
      remove_mean: True
    - _target_: schnetpack.transform.TorchNeighborList
      cutoff: ${cutoff}
    - _target_: schnetpack.transform.CastTo32

31 changes: 31 additions & 0 deletions src/schnetpack/configs/experiment/omdb.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# @package _global_

# Experiment: single-property model with a SchNet representation on the OMDB data config.
defaults:
  - override /model: single_property
  - override /model/representation: schnet
  - override /data: omdb
  - override /logger: tensorboard

# Global hyperparameters; `cutoff` is reused below via interpolation.
cutoff: 5.0
n_rbf: 20

lr: 1e-3

# Target property; also used to build the run name and by RemoveOffsets below.
property: band_gap
# Prefix is `omdb_`, not the `mp_` copied from the Materials Project experiment,
# so runs of the two experiments are not logged under the same name prefix.
name: omdb_${property}


# Keys under `data` are merged into the selected data config (omdb).
data:
  distance_unit: Ang
  # NOTE(review): units are keyed `energy` while the trained property is
  # `band_gap` - confirm the key matches a dataset property name.
  property_units:
    energy: eV
  # Each entry's parameters must stay nested under its own `_target_`.
  transforms:
    - _target_: schnetpack.transform.SubtractCenterOfMass
    - _target_: schnetpack.transform.RemoveOffsets
      property: ${property}
      remove_atomrefs: True
      remove_mean: True
    - _target_: schnetpack.transform.TorchNeighborList
      cutoff: ${cutoff}
    - _target_: schnetpack.transform.CastTo32

40 changes: 29 additions & 11 deletions src/schnetpack/data/datamodule.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from copy import copy
from typing import Optional, List, Dict, Tuple
from typing import Optional, List, Dict, Tuple, Union

import numpy as np
import pytorch_lightning as pl
Expand Down Expand Up @@ -31,9 +31,9 @@ def __init__(
self,
datapath: str,
batch_size: int,
num_train: int = None,
num_val: int = None,
num_test: int = None,
num_train: Union[int, float] = None,
num_val: Union[int, float] = None,
num_test: Optional[Union[int, float]] = None,
split_file: Optional[str] = "split.npz",
format: Optional[AtomsDataFormat] = None,
load_properties: Optional[List[str]] = None,
Expand All @@ -53,9 +53,9 @@ def __init__(
Args:
datapath: path to dataset
batch_size: (train) batch size
num_train: number of training examples
num_val: number of validation examples
num_test: number of test examples
num_train: number of training examples (absolute or relative)
num_val: number of validation examples (absolute or relative)
num_test: number of test examples (absolute or relative)
split_file: path to npz file with data partitions
format: dataset format
load_properties: subset of properties to load
Expand Down Expand Up @@ -117,8 +117,22 @@ def load_data(self):
)

def partition(self):
# split dataset
# TODO: handle IterDatasets
# handle relative dataset sizes
if self.num_train is not None and self.num_train <= 1.0:
self.num_train = round(self.num_train * len(self.dataset))
if self.num_val is not None and self.num_val <= 1.0:
self.num_val = min(
round(self.num_val * len(self.dataset)),
len(self.dataset) - self.num_train,
)
if self.num_test is not None and self.num_test <= 1.0:
self.num_test = min(
round(self.num_test * len(self.dataset)),
len(self.dataset) - self.num_train - self.num_val,
)

# split dataset
if self.split_file is not None and os.path.exists(self.split_file):
S = np.load(self.split_file)
train_idx = S["train_idx"].tolist()
Expand Down Expand Up @@ -152,9 +166,13 @@ def partition(self):
indices[offset - length : offset]
for offset, length in zip(offsets, lengths)
]
np.savez(
self.split_file, train_idx=train_idx, val_idx=val_idx, test_idx=test_idx
)
if self.split_file is not None:
np.savez(
self.split_file,
train_idx=train_idx,
val_idx=val_idx,
test_idx=test_idx,
)
self._train_dataset = self.dataset.subset(train_idx)
self._val_dataset = self.dataset.subset(val_idx)
self._test_dataset = self.dataset.subset(test_idx)
Expand Down
4 changes: 4 additions & 0 deletions src/schnetpack/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
from .qm9 import *
from .md17 import *
from .iso17 import *
from .ani1 import *
from .materials_project import *
from .omdb import *
Loading