Skip to content

Commit

Permalink
first draft of MultipleTrainEpochsSampler
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsLs committed Aug 8, 2024
1 parent 643c9a1 commit 63c9d11
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion src/schnetpack/data/sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Iterator, List, Callable

import numpy as np
from torch.utils.data import Sampler, WeightedRandomSampler
from torch.utils.data import Sampler, WeightedRandomSampler, RandomSampler

from schnetpack import properties
from schnetpack.data import BaseAtomsData
Expand All @@ -11,6 +11,7 @@
"StratifiedSampler",
"NumberOfAtomsCriterion",
"PropertyCriterion",
"MultipleTrainEpochsSampler",
]


Expand Down Expand Up @@ -95,3 +96,19 @@ def calculate_weights(self, partition_criterion):
weights = bin_weights[bin_indices]

return weights


class MultipleTrainEpochsSampler(RandomSampler):
def __init__(
self,
data_source,
num_samples=None,
n_train_epochs=1,
generator=None,
):
super().__init__(
data_source=data_source,
replacement=True,
num_samples=len(data_source) * n_train_epochs,
generator=generator,
)

0 comments on commit 63c9d11

Please sign in to comment.