From 494fbd76672b4a29e0a6f8736b604369346d8bdd Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Fri, 20 Sep 2024 19:43:38 +0200 Subject: [PATCH 1/7] Add random generator --- torchgeo/datamodules/agrifieldnet.py | 6 +++++- torchgeo/samplers/batch.py | 7 ++++++- torchgeo/samplers/single.py | 20 +++++++++++++++++--- torchgeo/samplers/utils.py | 10 +++++++--- 4 files changed, 35 insertions(+), 8 deletions(-) diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index bed6365d4a2..c5b92b6b01a 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -74,7 +74,11 @@ def setup(self, stage: str) -> None: if stage in ['fit']: self.train_batch_sampler = RandomBatchGeoSampler( - self.train_dataset, self.patch_size, self.batch_size, self.length + self.train_dataset, + self.patch_size, + self.batch_size, + self.length, + generator=generator, ) if stage in ['fit', 'validate']: self.val_sampler = GridGeoSampler( diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 22726f74b2c..396ad0f0c7b 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -70,6 +70,7 @@ def __init__( length: int | None = None, roi: BoundingBox | None = None, units: Units = Units.PIXELS, + generator: torch.Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -97,9 +98,11 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units + generator: random number generator """ super().__init__(dataset, roi) self.size = _to_tuple(size) + self.generator = generator if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) @@ -144,7 +147,9 @@ def __iter__(self) -> Iterator[list[BoundingBox]]: # Choose random indices within that tile batch = [] for _ in range(self.batch_size): - bounding_box = get_random_bounding_box(bounds, self.size, self.res) + bounding_box = get_random_bounding_box( + bounds, self.size, self.res, self.generator + ) batch.append(bounding_box) yield batch diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 094142cb647..ea943db3d53 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -5,6 +5,7 @@ import abc from collections.abc import Callable, Iterable, Iterator +from functools import partial import torch from rtree.index import Index, Property @@ -72,6 +73,7 @@ def __init__( length: int | None = None, roi: BoundingBox | None = None, units: Units = Units.PIXELS, + generator: torch.Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -98,6 +100,8 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units + generator: The random generator used for sampling. + """ super().__init__(dataset, roi) self.size = _to_tuple(size) @@ -105,6 +109,7 @@ def __init__( if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) + self.generator = generator self.length = 0 self.hits = [] areas = [] @@ -142,7 +147,9 @@ def __iter__(self) -> Iterator[BoundingBox]: bounds = BoundingBox(*hit.bounds) # Choose a random index within that tile - bounding_box = get_random_bounding_box(bounds, self.size, self.res) + bounding_box = get_random_bounding_box( + bounds, self.size, self.res, self.generator + ) yield bounding_box @@ -270,7 +277,11 @@ class PreChippedGeoSampler(GeoSampler): """ def __init__( - self, dataset: GeoDataset, roi: BoundingBox | None = None, shuffle: bool = False + self, + dataset: GeoDataset, + roi: BoundingBox | None = None, + shuffle: bool = False, + generator: torch.Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -281,9 +292,12 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) shuffle: if True, reshuffle data at every epoch + generator: The random number generator used in combination with shuffle. + """ super().__init__(dataset, roi) self.shuffle = shuffle + self.generator = generator self.hits = [] for hit in self.index.intersection(tuple(self.roi), objects=True): @@ -297,7 +311,7 @@ def __iter__(self) -> Iterator[BoundingBox]: """ generator: Callable[[int], Iterable[int]] = range if self.shuffle: - generator = torch.randperm + generator = partial(torch.randperm, generator=self.generator) for idx in generator(len(self)): yield BoundingBox(*self.hits[idx].bounds) diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index a1fca673a3a..258f74a5425 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -35,7 +35,10 @@ def _to_tuple(value: tuple[float, float] | float) -> tuple[float, float]: def get_random_bounding_box( - bounds: BoundingBox, size: tuple[float, float] | float, res: float + bounds: BoundingBox, + size: tuple[float, float] | float, + res: float, + generator: torch.Generator | None = None, ) -> BoundingBox: """Returns a random bounding box within a given bounding box. @@ -50,6 +53,7 @@ def get_random_bounding_box( bounds: the larger bounding box to sample from size: the size of the bounding box to sample res: the resolution of the image + generator: random number generator Returns: randomly sampled bounding box from the extent of the input @@ -64,8 +68,8 @@ def get_random_bounding_box( miny = bounds.miny # Use an integer multiple of res to avoid resampling - minx += int(torch.rand(1).item() * width) * res - miny += int(torch.rand(1).item() * height) * res + minx += int(torch.rand(1, generator=generator).item() * width) * res + miny += int(torch.rand(1, generator=generator).item() * height) * res maxx = minx + t_size[1] maxy = miny + t_size[0] From 5a9e107fd1b5556f177d9607e426e021ab57a75a Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Fri, 20 Sep 2024 20:21:19 +0200 Subject: [PATCH 2/7] Add tests for seed --- tests/samplers/test_batch.py | 16 ++++++++++++++++ tests/samplers/test_single.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 59c8aaa00be..20ad33a58c9 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -6,6 +6,7 @@ from itertools import product import pytest +import torch from _pytest.fixtures import SubRequest from rasterio.crs import CRS from torch.utils.data import DataLoader @@ -144,6 +145,21 @@ def test_weighted_sampling(self) -> None: for bbox in batch: assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) + def test_random_seed(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + generator = torch.manual_seed(0) + sampler = RandomBatchGeoSampler(ds, 1, 1, generator=generator) + for bbox in sampler: + sample1 = bbox + break + + sampler = RandomBatchGeoSampler(ds, 1, 1, generator=generator) + for bbox in sampler: + sample2 = bbox + break + assert sample1 == sample2 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 1416368098a..15f1025f672 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -6,6 +6,7 @@ from itertools import product import pytest +import torch from _pytest.fixtures import SubRequest from rasterio.crs import CRS from torch.utils.data import DataLoader @@ -139,6 +140,21 @@ def test_weighted_sampling(self) -> None: for bbox in sampler: assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) + def test_random_seed(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + generator = torch.manual_seed(0) + sampler = RandomGeoSampler(ds, 1, 1, generator=generator) + for bbox in sampler: + sample1 = bbox + break + + sampler = RandomGeoSampler(ds, 1, 1, generator=generator) + for bbox in sampler: + sample2 = bbox + break + assert sample1 == sample2 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( @@ -288,6 +304,22 @@ def test_point_data(self) -> None: for _ in sampler: continue + def test_shuffle_seed(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + ds.index.insert(1, (0, 11, 0, 11, 0, 11)) + generator = torch.manual_seed(0) + sampler = PreChippedGeoSampler(ds, shuffle=True, generator=generator) + for bbox in sampler: + sample1 = bbox + break + + sampler = PreChippedGeoSampler(ds, shuffle=True, generator=generator) + for bbox in sampler: + sample2 = bbox + break + assert sample1 == sample2 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( From 46e1f11d440ecf1363393d7e616666cbd7f3e9f5 Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Fri, 20 Sep 2024 18:49:44 +0000 Subject: [PATCH 3/7] pass generator every sampler --- tests/samplers/test_batch.py | 5 ++--- tests/samplers/test_single.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 20ad33a58c9..16b99e16a93 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -148,13 +148,12 @@ def test_weighted_sampling(self) -> None: def test_random_seed(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - generator = torch.manual_seed(0) - sampler = RandomBatchGeoSampler(ds, 1, 1, generator=generator) + sampler = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) for bbox in sampler: sample1 = bbox break - sampler = RandomBatchGeoSampler(ds, 1, 1, generator=generator) + sampler = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) for bbox in sampler: sample2 = bbox break diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 15f1025f672..abbf22d2727 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -308,15 +308,18 @@ def test_shuffle_seed(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) ds.index.insert(1, (0, 11, 0, 11, 0, 11)) - generator = torch.manual_seed(0) - sampler = PreChippedGeoSampler(ds, shuffle=True, generator=generator) - for bbox in sampler: + generator = torch.manual_seed(2) + sampler1 = PreChippedGeoSampler(ds, shuffle=True, generator=generator) + for bbox in sampler1: sample1 = bbox + print(sample1) break - sampler = PreChippedGeoSampler(ds, shuffle=True, generator=generator) - for bbox in sampler: + generator = torch.manual_seed(2) + sampler2 = PreChippedGeoSampler(ds, shuffle=True, generator=generator) + for bbox in sampler2: sample2 = bbox + print(sample2) break assert sample1 == sample2 From 3d320e0ced569f433d632c82f5d27054ce6bc5ca Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Mon, 23 Sep 2024 13:41:24 +0200 Subject: [PATCH 4/7] Simplification of tests, docstring updates --- tests/samplers/test_batch.py | 13 ++++--------- tests/samplers/test_single.py | 14 ++++---------- torchgeo/samplers/batch.py | 8 ++++++-- torchgeo/samplers/single.py | 8 ++++++-- torchgeo/samplers/utils.py | 8 ++++++-- 5 files changed, 26 insertions(+), 25 deletions(-) diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 16b99e16a93..01dea100ca8 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -148,15 +148,10 @@ def test_weighted_sampling(self) -> None: def test_random_seed(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - sampler = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) - for bbox in sampler: - sample1 = bbox - break - - sampler = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) - for bbox in sampler: - sample2 = bbox - break + sampler1 = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) + sampler2 = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) + sample1 = next(iter(sampler1)) + sample2 = next(iter(sampler2)) assert sample1 == sample2 @pytest.mark.slow diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index abbf22d2727..ef6e11e407b 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -143,16 +143,10 @@ def test_weighted_sampling(self) -> None: def test_random_seed(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - generator = torch.manual_seed(0) - sampler = RandomGeoSampler(ds, 1, 1, generator=generator) - for bbox in sampler: - sample1 = bbox - break - - sampler = RandomGeoSampler(ds, 1, 1, generator=generator) - for bbox in sampler: - sample2 = bbox - break + sampler1 = RandomGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) + sampler2 = RandomGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) + sample1 = next(iter(sampler1)) + sample2 = next(iter(sampler2)) assert sample1 == sample2 @pytest.mark.slow diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 396ad0f0c7b..5d7147cd6b0 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -8,6 +8,7 @@ import torch from rtree.index import Index, Property +from torch import Generator from torch.utils.data import Sampler from ..datasets import BoundingBox, GeoDataset @@ -70,7 +71,7 @@ def __init__( length: int | None = None, roi: BoundingBox | None = None, units: Units = Units.PIXELS, - generator: torch.Generator | None = None, + generator: Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -87,6 +88,9 @@ def __init__( .. versionchanged:: 0.4 ``length`` parameter is now optional, a reasonable default will be used + .. versionadded:: 0.7 + The *generator* parameter. + Args: dataset: dataset to index from size: dimensions of each :term:`patch` @@ -98,7 +102,7 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units - generator: random number generator + generator: pseudo-random number generator (PRNG). """ super().__init__(dataset, roi) self.size = _to_tuple(size) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index ea943db3d53..5f7e461c13d 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -9,6 +9,7 @@ import torch from rtree.index import Index, Property +from torch import Generator from torch.utils.data import Sampler from ..datasets import BoundingBox, GeoDataset @@ -73,7 +74,7 @@ def __init__( length: int | None = None, roi: BoundingBox | None = None, units: Units = Units.PIXELS, - generator: torch.Generator | None = None, + generator: Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -90,6 +91,9 @@ def __init__( .. versionchanged:: 0.4 ``length`` parameter is now optional, a reasonable default will be used + .. versionadded:: 0.7 + The *generator* parameter. + Args: dataset: dataset to index from size: dimensions of each :term:`patch` @@ -100,7 +104,7 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units - generator: The random generator used for sampling. + generator: pseudo-random number generator (PRNG). """ super().__init__(dataset, roi) diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index 258f74a5425..5af4b6836e4 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -7,6 +7,7 @@ from typing import overload import torch +from torch import Generator from ..datasets import BoundingBox @@ -38,7 +39,7 @@ def get_random_bounding_box( bounds: BoundingBox, size: tuple[float, float] | float, res: float, - generator: torch.Generator | None = None, + generator: Generator | None = None, ) -> BoundingBox: """Returns a random bounding box within a given bounding box. @@ -49,11 +50,14 @@ def get_random_bounding_box( * a ``tuple`` of two floats - in which case, the first *float* is used for the height dimension, and the second *float* for the width dimension + .. versionadded:: 0.7 + The *generator* parameter. + Args: bounds: the larger bounding box to sample from size: the size of the bounding box to sample res: the resolution of the image - generator: random number generator + generator: pseudo-random number generator (PRNG). Returns: randomly sampled bounding box from the extent of the input From a80b66d6b745f0e6da1e103504d8969afd7b1fe8 Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Mon, 23 Sep 2024 14:11:36 +0200 Subject: [PATCH 5/7] try to pass docs build --- torchgeo/samplers/batch.py | 2 +- torchgeo/samplers/single.py | 3 +-- torchgeo/samplers/utils.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 5d7147cd6b0..686b458ce24 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -89,7 +89,7 @@ def __init__( ``length`` parameter is now optional, a reasonable default will be used .. versionadded:: 0.7 - The *generator* parameter. + The *generator* parameter. Args: dataset: dataset to index from diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 5f7e461c13d..b81597a54a5 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -92,7 +92,7 @@ def __init__( ``length`` parameter is now optional, a reasonable default will be used .. versionadded:: 0.7 - The *generator* parameter. + The *generator* parameter. Args: dataset: dataset to index from @@ -105,7 +105,6 @@ def __init__( (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units generator: pseudo-random number generator (PRNG). - """ super().__init__(dataset, roi) self.size = _to_tuple(size) diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index 5af4b6836e4..48ad760f928 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -51,7 +51,7 @@ def get_random_bounding_box( height dimension, and the second *float* for the width dimension .. versionadded:: 0.7 - The *generator* parameter. + The *generator* parameter. Args: bounds: the larger bounding box to sample from From 1ac686060a9afa04b7e0b27e8959b8737ec3801c Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Mon, 23 Sep 2024 14:26:52 +0200 Subject: [PATCH 6/7] forgotten updates --- tests/samplers/test_single.py | 21 ++++++++------------- torchgeo/samplers/single.py | 5 ++++- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index ef6e11e407b..839466de78e 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -302,19 +302,14 @@ def test_shuffle_seed(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) ds.index.insert(1, (0, 11, 0, 11, 0, 11)) - generator = torch.manual_seed(2) - sampler1 = PreChippedGeoSampler(ds, shuffle=True, generator=generator) - for bbox in sampler1: - sample1 = bbox - print(sample1) - break - - generator = torch.manual_seed(2) - sampler2 = PreChippedGeoSampler(ds, shuffle=True, generator=generator) - for bbox in sampler2: - sample2 = bbox - print(sample2) - break + sampler1 = PreChippedGeoSampler( + ds, shuffle=True, generator=torch.manual_seed(2) + ) + sampler2 = PreChippedGeoSampler( + ds, shuffle=True, generator=torch.manual_seed(2) + ) + sample1 = next(iter(sampler1)) + sample2 = next(iter(sampler2)) assert sample1 == sample2 @pytest.mark.slow diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index b81597a54a5..6fa4331c4b7 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -290,12 +290,15 @@ def __init__( .. versionadded:: 0.3 + .. versionadded:: 0.7 + The *generator* parameter. + Args: dataset: dataset to index from roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) shuffle: if True, reshuffle data at every epoch - generator: The random number generator used in combination with shuffle. + generator: pseudo-random number generator (PRNG) used in combination with shuffle. """ super().__init__(dataset, roi) From 40babfc6435a649ce28d51175f8912a99d3dbb47 Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Mon, 23 Sep 2024 14:39:52 +0200 Subject: [PATCH 7/7] equal should have been unequal --- tests/samplers/test_single.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 839466de78e..743c8be70da 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -310,7 +310,7 @@ def test_shuffle_seed(self) -> None: ) sample1 = next(iter(sampler1)) sample2 = next(iter(sampler2)) - assert sample1 == sample2 + assert sample1 != sample2 @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2])