Skip to content

Commit

Permalink
GridGeoSampler: change stride of last patch to sample entire ROI (#630)
Browse files Browse the repository at this point in the history
* Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning

* style and mypy fixes

* black test fix

* Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning

* style and mypy fixes

* black test fix

* single.py: adapt gridgeosampler to sample beyond limit of ROI for a partial patch (to be padded)
test_single.py: add tests for multiple limit cases (see issue #448)

* format for black and flake8

* format for black and flake8

* once again, format for black and flake8

* Revert "Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning"

This reverts commit cb554c6

* adapt unit tests, remove warnings

* flake8: remove warnings import

* Address some comments

* Simplify computation of # rows/cols

* Document this new feature

* Fix size of ceiling symbol

* Simplify tests

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
  • Loading branch information
remtav and adamjstewart authored Sep 3, 2022
1 parent f41619a commit 7fa0fd4
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 13 deletions.
35 changes: 29 additions & 6 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ def test_iter(self, sampler: GridGeoSampler) -> None:
)

def test_len(self, sampler: GridGeoSampler) -> None:
rows = ((100 - sampler.size[0]) // sampler.stride[0]) + 1
cols = ((100 - sampler.size[1]) // sampler.stride[1]) + 1
length = rows * cols * 2
rows = math.ceil((100 - sampler.size[0]) / sampler.stride[0]) + 1
cols = math.ceil((100 - sampler.size[1]) / sampler.stride[1]) + 1
length = rows * cols * 2 # two items in dataset
assert len(sampler) == length

def test_roi(self, dataset: CustomGeoDataset) -> None:
Expand All @@ -194,12 +194,35 @@ def test_roi(self, dataset: CustomGeoDataset) -> None:
assert query in roi

def test_small_area(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 1, 0, 1, 0, 1))
sampler = GridGeoSampler(ds, 2, 10)
assert len(sampler) == 0

def test_tiles_side_by_side(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (20, 21, 20, 21, 20, 21))
ds.index.insert(0, (0, 10, 10, 20, 0, 10))
sampler = GridGeoSampler(ds, 2, 10)
for _ in sampler:
continue
for bbox in sampler:
assert bbox.area > 0

def test_integer_multiple(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(ds, 10, 10, units=Units.CRS)
iterator = iter(sampler)
assert len(sampler) == 1
assert next(iterator) == BoundingBox(0, 10, 0, 10, 0, 10)

def test_float_multiple(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 6, 0, 5, 0, 10))
sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
iterator = iter(sampler)
assert len(sampler) == 2
assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10)
assert next(iterator) == BoundingBox(1, 6, 0, 5, 0, 10)

@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
Expand Down
48 changes: 41 additions & 7 deletions torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""TorchGeo samplers."""

import abc
import math
from typing import Callable, Iterable, Iterator, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -146,7 +147,7 @@ def __len__(self) -> int:


class GridGeoSampler(GeoSampler):
"""Samples elements in a grid-like fashion.
r"""Samples elements in a grid-like fashion.
This is particularly useful during evaluation when you want to make predictions for
an entire region of interest. You want to minimize the amount of redundant
Expand All @@ -158,6 +159,21 @@ class GridGeoSampler(GeoSampler):
The overlap between each chip (``chip_size - stride``) should be approximately equal
to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of
the CNN.
Note that the stride of the final set of chips in each row/column may be adjusted so
that the entire :term:`tile` is sampled without exceeding the bounds of the dataset.
Let :math:`i` be the size of the input tile. Let :math:`k` be the requested size of
the output patch. Let :math:`s` be the requested stride. Let :math:`o` be the number
of output rows/columns sampled from each tile. :math:`o` can then be computed as:
.. math::
o = \left\lceil \frac{i - k}{s} \right\rceil + 1
This is almost identical to relationship 5 in
https://doi.org/10.48550/arXiv.1603.07285. However, we use ceiling instead of floor
because we want to include the final remaining chip.
"""

def __init__(
Expand Down Expand Up @@ -200,17 +216,23 @@ def __init__(
for hit in self.index.intersection(tuple(self.roi), objects=True):
bounds = BoundingBox(*hit.bounds)
if (
bounds.maxx - bounds.minx > self.size[1]
and bounds.maxy - bounds.miny > self.size[0]
bounds.maxx - bounds.minx >= self.size[1]
and bounds.maxy - bounds.miny >= self.size[0]
):
self.hits.append(hit)

self.length = 0
for hit in self.hits:
bounds = BoundingBox(*hit.bounds)

rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
rows = (
math.ceil((bounds.maxy - bounds.miny - self.size[0]) / self.stride[0])
+ 1
)
cols = (
math.ceil((bounds.maxx - bounds.minx - self.size[1]) / self.stride[1])
+ 1
)
self.length += rows * cols

def __iter__(self) -> Iterator[BoundingBox]:
Expand All @@ -223,8 +245,14 @@ def __iter__(self) -> Iterator[BoundingBox]:
for hit in self.hits:
bounds = BoundingBox(*hit.bounds)

rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
rows = (
math.ceil((bounds.maxy - bounds.miny - self.size[0]) / self.stride[0])
+ 1
)
cols = (
math.ceil((bounds.maxx - bounds.minx - self.size[1]) / self.stride[1])
+ 1
)

mint = bounds.mint
maxt = bounds.maxt
Expand All @@ -233,11 +261,17 @@ def __iter__(self) -> Iterator[BoundingBox]:
for i in range(rows):
miny = bounds.miny + i * self.stride[0]
maxy = miny + self.size[0]
if maxy > bounds.maxy:
maxy = bounds.maxy
miny = bounds.maxy - self.size[0]

# For each column...
for j in range(cols):
minx = bounds.minx + j * self.stride[1]
maxx = minx + self.size[1]
if maxx > bounds.maxx:
maxx = bounds.maxx
minx = bounds.maxx - self.size[1]

yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)

Expand Down

0 comments on commit 7fa0fd4

Please sign in to comment.