From 7fa0fd429e5156e3f71862253ba1124b0244b260 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Tavon?= <34774759+remtav@users.noreply.github.com> Date: Sat, 3 Sep 2022 00:11:14 -0400 Subject: [PATCH] GridGeoSampler: change stride of last patch to sample entire ROI (#630) * Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning * style and mypy fixes * black test fix * Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning * style and mypy fixes * black test fix * single.py: adapt gridgeosampler to sample beyond limit of ROI for a partial patch (to be padded) test_single.py: add tests for multiple limit cases (see issue #448) * format for black and flake8 * format for black and flake8 * once again, format for black and flake8 * Revert "Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning" This reverts commit cb554c67 * adapt unit tests, remove warnings * flake8: remove warnings import * Address some comments * Simplify computation of # rows/cols * Document this new feature * Fix size of ceiling symbol * Simplify tests Co-authored-by: Adam J. 
Stewart --- tests/samplers/test_single.py | 35 ++++++++++++++++++++----- torchgeo/samplers/single.py | 48 ++++++++++++++++++++++++++++++----- 2 files changed, 70 insertions(+), 13 deletions(-) diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 556de8c6587..0528dfe78c4 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -182,9 +182,9 @@ def test_iter(self, sampler: GridGeoSampler) -> None: ) def test_len(self, sampler: GridGeoSampler) -> None: - rows = ((100 - sampler.size[0]) // sampler.stride[0]) + 1 - cols = ((100 - sampler.size[1]) // sampler.stride[1]) + 1 - length = rows * cols * 2 + rows = math.ceil((100 - sampler.size[0]) / sampler.stride[0]) + 1 + cols = math.ceil((100 - sampler.size[1]) / sampler.stride[1]) + 1 + length = rows * cols * 2 # two items in dataset assert len(sampler) == length def test_roi(self, dataset: CustomGeoDataset) -> None: @@ -194,12 +194,35 @@ def test_roi(self, dataset: CustomGeoDataset) -> None: assert query in roi def test_small_area(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 1, 0, 1, 0, 1)) + sampler = GridGeoSampler(ds, 2, 10) + assert len(sampler) == 0 + + def test_tiles_side_by_side(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - ds.index.insert(1, (20, 21, 20, 21, 20, 21)) + ds.index.insert(0, (0, 10, 10, 20, 0, 10)) sampler = GridGeoSampler(ds, 2, 10) - for _ in sampler: - continue + for bbox in sampler: + assert bbox.area > 0 + + def test_integer_multiple(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler = GridGeoSampler(ds, 10, 10, units=Units.CRS) + iterator = iter(sampler) + assert len(sampler) == 1 + assert next(iterator) == BoundingBox(0, 10, 0, 10, 0, 10) + + def test_float_multiple(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 6, 0, 5, 0, 10)) + sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS) + iterator = iter(sampler) + assert 
len(sampler) == 2 + assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10) + assert next(iterator) == BoundingBox(1, 6, 0, 5, 0, 10) @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 92930a24382..e063d9ecbd4 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -4,6 +4,7 @@ """TorchGeo samplers.""" import abc +import math from typing import Callable, Iterable, Iterator, Optional, Tuple, Union import torch @@ -146,7 +147,7 @@ def __len__(self) -> int: class GridGeoSampler(GeoSampler): - """Samples elements in a grid-like fashion. + r"""Samples elements in a grid-like fashion. This is particularly useful during evaluation when you want to make predictions for an entire region of interest. You want to minimize the amount of redundant @@ -158,6 +159,21 @@ class GridGeoSampler(GeoSampler): The overlap between each chip (``chip_size - stride``) should be approximately equal to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of the CNN. + + Note that the stride of the final set of chips in each row/column may be adjusted so + that the entire :term:`tile` is sampled without exceeding the bounds of the dataset. + + Let :math:`i` be the size of the input tile. Let :math:`k` be the requested size of + the output patch. Let :math:`s` be the requested stride. Let :math:`o` be the number + of output rows/columns sampled from each tile. :math:`o` can then be computed as: + + .. math:: + + o = \left\lceil \frac{i - k}{s} \right\rceil + 1 + + This is almost identical to relationship 5 in + https://doi.org/10.48550/arXiv.1603.07285. However, we use ceiling instead of floor + because we want to include the final remaining chip. 
""" def __init__( @@ -200,8 +216,8 @@ def __init__( for hit in self.index.intersection(tuple(self.roi), objects=True): bounds = BoundingBox(*hit.bounds) if ( - bounds.maxx - bounds.minx > self.size[1] - and bounds.maxy - bounds.miny > self.size[0] + bounds.maxx - bounds.minx >= self.size[1] + and bounds.maxy - bounds.miny >= self.size[0] ): self.hits.append(hit) @@ -209,8 +225,14 @@ def __init__( for hit in self.hits: bounds = BoundingBox(*hit.bounds) - rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1 - cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1 + rows = ( + math.ceil((bounds.maxy - bounds.miny - self.size[0]) / self.stride[0]) + + 1 + ) + cols = ( + math.ceil((bounds.maxx - bounds.minx - self.size[1]) / self.stride[1]) + + 1 + ) self.length += rows * cols def __iter__(self) -> Iterator[BoundingBox]: @@ -223,8 +245,14 @@ def __iter__(self) -> Iterator[BoundingBox]: for hit in self.hits: bounds = BoundingBox(*hit.bounds) - rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1 - cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1 + rows = ( + math.ceil((bounds.maxy - bounds.miny - self.size[0]) / self.stride[0]) + + 1 + ) + cols = ( + math.ceil((bounds.maxx - bounds.minx - self.size[1]) / self.stride[1]) + + 1 + ) mint = bounds.mint maxt = bounds.maxt @@ -233,11 +261,17 @@ def __iter__(self) -> Iterator[BoundingBox]: for i in range(rows): miny = bounds.miny + i * self.stride[0] maxy = miny + self.size[0] + if maxy > bounds.maxy: + maxy = bounds.maxy + miny = bounds.maxy - self.size[0] # For each column... for j in range(cols): minx = bounds.minx + j * self.stride[1] maxx = minx + self.size[1] + if maxx > bounds.maxx: + maxx = bounds.maxx + minx = bounds.maxx - self.size[1] yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)