Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sampling random generator #2309

Merged
merged 10 commits into from
Sep 23, 2024
10 changes: 10 additions & 0 deletions tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import product

import pytest
import torch
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -144,6 +145,15 @@ def test_weighted_sampling(self) -> None:
for bbox in batch:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_random_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler1 = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
sampler2 = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
sample1 = next(iter(sampler1))
sample2 = next(iter(sampler2))
assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down
29 changes: 29 additions & 0 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import product

import pytest
import torch
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -139,6 +140,15 @@ def test_weighted_sampling(self) -> None:
for bbox in sampler:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_random_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler1 = RandomGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
sampler2 = RandomGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
sample1 = next(iter(sampler1))
sample2 = next(iter(sampler2))
assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down Expand Up @@ -288,6 +298,25 @@ def test_point_data(self) -> None:
for _ in sampler:
continue

def test_shuffle_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (0, 11, 0, 11, 0, 11))
generator = torch.manual_seed(2)
sampler1 = PreChippedGeoSampler(ds, shuffle=True, generator=generator)
for bbox in sampler1:
sample1 = bbox
print(sample1)
break
sfalkena marked this conversation as resolved.
Show resolved Hide resolved

generator = torch.manual_seed(2)
sampler2 = PreChippedGeoSampler(ds, shuffle=True, generator=generator)
for bbox in sampler2:
sample2 = bbox
print(sample2)
break
assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down
6 changes: 5 additions & 1 deletion torchgeo/datamodules/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ def setup(self, stage: str) -> None:

if stage in ['fit']:
self.train_batch_sampler = RandomBatchGeoSampler(
self.train_dataset, self.patch_size, self.batch_size, self.length
self.train_dataset,
self.patch_size,
self.batch_size,
self.length,
generator=generator,
)
if stage in ['fit', 'validate']:
self.val_sampler = GridGeoSampler(
Expand Down
11 changes: 10 additions & 1 deletion torchgeo/samplers/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import torch
from rtree.index import Index, Property
from torch import Generator
from torch.utils.data import Sampler

from ..datasets import BoundingBox, GeoDataset
Expand Down Expand Up @@ -70,6 +71,7 @@ def __init__(
length: int | None = None,
roi: BoundingBox | None = None,
units: Units = Units.PIXELS,
generator: Generator | None = None,
) -> None:
"""Initialize a new Sampler instance.

Expand All @@ -86,6 +88,9 @@ def __init__(
.. versionchanged:: 0.4
``length`` parameter is now optional, a reasonable default will be used

.. versionadded:: 0.7
The *generator* parameter.
sfalkena marked this conversation as resolved.
Show resolved Hide resolved

Args:
dataset: dataset to index from
size: dimensions of each :term:`patch`
Expand All @@ -97,9 +102,11 @@ def __init__(
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
units: defines if ``size`` is in pixel or CRS units
generator: pseudo-random number generator (PRNG).
"""
super().__init__(dataset, roi)
self.size = _to_tuple(size)
self.generator = generator

if units == Units.PIXELS:
self.size = (self.size[0] * self.res, self.size[1] * self.res)
Expand Down Expand Up @@ -144,7 +151,9 @@ def __iter__(self) -> Iterator[list[BoundingBox]]:
# Choose random indices within that tile
batch = []
for _ in range(self.batch_size):
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
bounding_box = get_random_bounding_box(
bounds, self.size, self.res, self.generator
)
batch.append(bounding_box)

yield batch
Expand Down
24 changes: 21 additions & 3 deletions torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

import abc
from collections.abc import Callable, Iterable, Iterator
from functools import partial

import torch
from rtree.index import Index, Property
from torch import Generator
from torch.utils.data import Sampler

from ..datasets import BoundingBox, GeoDataset
Expand Down Expand Up @@ -72,6 +74,7 @@ def __init__(
length: int | None = None,
roi: BoundingBox | None = None,
units: Units = Units.PIXELS,
generator: Generator | None = None,
) -> None:
"""Initialize a new Sampler instance.

Expand All @@ -88,6 +91,9 @@ def __init__(
.. versionchanged:: 0.4
``length`` parameter is now optional, a reasonable default will be used

.. versionadded:: 0.7
The *generator* parameter.

Args:
dataset: dataset to index from
size: dimensions of each :term:`patch`
Expand All @@ -98,13 +104,16 @@ def __init__(
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
units: defines if ``size`` is in pixel or CRS units
generator: pseudo-random number generator (PRNG).

"""
super().__init__(dataset, roi)
self.size = _to_tuple(size)

if units == Units.PIXELS:
self.size = (self.size[0] * self.res, self.size[1] * self.res)

self.generator = generator
self.length = 0
self.hits = []
areas = []
Expand Down Expand Up @@ -142,7 +151,9 @@ def __iter__(self) -> Iterator[BoundingBox]:
bounds = BoundingBox(*hit.bounds)

# Choose a random index within that tile
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
bounding_box = get_random_bounding_box(
bounds, self.size, self.res, self.generator
)

yield bounding_box

Expand Down Expand Up @@ -270,7 +281,11 @@ class PreChippedGeoSampler(GeoSampler):
"""

def __init__(
self, dataset: GeoDataset, roi: BoundingBox | None = None, shuffle: bool = False
self,
dataset: GeoDataset,
roi: BoundingBox | None = None,
shuffle: bool = False,
generator: torch.Generator | None = None,
) -> None:
"""Initialize a new Sampler instance.

Expand All @@ -281,9 +296,12 @@ def __init__(
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
shuffle: if True, reshuffle data at every epoch
generator: The random number generator used in combination with shuffle.

sfalkena marked this conversation as resolved.
Show resolved Hide resolved
"""
super().__init__(dataset, roi)
self.shuffle = shuffle
self.generator = generator

self.hits = []
for hit in self.index.intersection(tuple(self.roi), objects=True):
Expand All @@ -297,7 +315,7 @@ def __iter__(self) -> Iterator[BoundingBox]:
"""
generator: Callable[[int], Iterable[int]] = range
if self.shuffle:
generator = torch.randperm
generator = partial(torch.randperm, generator=self.generator)

for idx in generator(len(self)):
yield BoundingBox(*self.hits[idx].bounds)
Expand Down
14 changes: 11 additions & 3 deletions torchgeo/samplers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import overload

import torch
from torch import Generator

from ..datasets import BoundingBox

Expand Down Expand Up @@ -35,7 +36,10 @@ def _to_tuple(value: tuple[float, float] | float) -> tuple[float, float]:


def get_random_bounding_box(
bounds: BoundingBox, size: tuple[float, float] | float, res: float
bounds: BoundingBox,
size: tuple[float, float] | float,
res: float,
generator: Generator | None = None,
) -> BoundingBox:
"""Returns a random bounding box within a given bounding box.

Expand All @@ -46,10 +50,14 @@ def get_random_bounding_box(
* a ``tuple`` of two floats - in which case, the first *float* is used for the
height dimension, and the second *float* for the width dimension

.. versionadded:: 0.7
The *generator* parameter.

Args:
bounds: the larger bounding box to sample from
size: the size of the bounding box to sample
res: the resolution of the image
generator: pseudo-random number generator (PRNG).

Returns:
randomly sampled bounding box from the extent of the input
Expand All @@ -64,8 +72,8 @@ def get_random_bounding_box(
miny = bounds.miny

# Use an integer multiple of res to avoid resampling
minx += int(torch.rand(1).item() * width) * res
miny += int(torch.rand(1).item() * height) * res
minx += int(torch.rand(1, generator=generator).item() * width) * res
miny += int(torch.rand(1, generator=generator).item() * height) * res

maxx = minx + t_size[1]
maxy = miny + t_size[0]
Expand Down