Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GridGeoSampler: change stride of last patch to sample entire ROI #630

Merged
merged 23 commits into from
Sep 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
ad093fb
Adjust minx/miny with a smaller stride for the last sample per row/co…
remtav Mar 2, 2022
855fee1
style and mypy fixes
remtav Mar 14, 2022
eb33fe0
black test fix
remtav Mar 14, 2022
cb554c6
Adjust minx/miny with a smaller stride for the last sample per row/co…
remtav Mar 2, 2022
1a0236c
style and mypy fixes
remtav Mar 14, 2022
bfadf76
black test fix
remtav Mar 14, 2022
6942b5f
single.py: adapt gridgeosampler to sample beyond limit of ROI for a p…
remtav Jun 28, 2022
138855b
format for black and flake8
remtav Jun 28, 2022
37d6055
format for black and flake8
remtav Jun 28, 2022
b0ab3fa
once again, format for black and flake8
remtav Jun 28, 2022
45b3490
Merge branch 'microsoft:main' into samplers/gridgeosampler_bounds
remtav Aug 25, 2022
720cf5b
Merge remote-tracking branch 'origin/samplers/gridgeosampler_bounds' …
remtav Aug 25, 2022
af5a3d1
Revert "Adjust minx/miny with a smaller stride for the last sample pe…
remtav Aug 29, 2022
e588385
Merge branch 'microsoft:main' into samplers/gridgeosampler_bounds
remtav Aug 29, 2022
0e61b1d
adapt unit tests, remove warnings
remtav Aug 29, 2022
6c623d8
flake8: remove warnings import
remtav Aug 30, 2022
fd9b69a
Merge branch 'main' into samplers/gridgeosampler_bounds
remtav Aug 30, 2022
13daca1
Merge branch 'main' into samplers/gridgeosampler_bounds
adamjstewart Sep 3, 2022
9d68d1e
Address some comments
adamjstewart Sep 3, 2022
8660f28
Simplify computation of # rows/cols
adamjstewart Sep 3, 2022
9536970
Document this new feature
adamjstewart Sep 3, 2022
32d877e
Fix size of ceiling symbol
adamjstewart Sep 3, 2022
a741129
Simplify tests
adamjstewart Sep 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ def test_iter(self, sampler: GridGeoSampler) -> None:
)

def test_len(self, sampler: GridGeoSampler) -> None:
rows = ((100 - sampler.size[0]) // sampler.stride[0]) + 1
cols = ((100 - sampler.size[1]) // sampler.stride[1]) + 1
length = rows * cols * 2
rows = math.ceil((100 - sampler.size[0]) / sampler.stride[0]) + 1
cols = math.ceil((100 - sampler.size[1]) / sampler.stride[1]) + 1
length = rows * cols * 2 # two items in dataset
assert len(sampler) == length

def test_roi(self, dataset: CustomGeoDataset) -> None:
Expand All @@ -194,12 +194,35 @@ def test_roi(self, dataset: CustomGeoDataset) -> None:
assert query in roi

def test_small_area(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 1, 0, 1, 0, 1))
sampler = GridGeoSampler(ds, 2, 10)
assert len(sampler) == 0

def test_tiles_side_by_side(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (20, 21, 20, 21, 20, 21))
ds.index.insert(0, (0, 10, 10, 20, 0, 10))
sampler = GridGeoSampler(ds, 2, 10)
for _ in sampler:
continue
for bbox in sampler:
assert bbox.area > 0

def test_integer_multiple(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(ds, 10, 10, units=Units.CRS)
iterator = iter(sampler)
assert len(sampler) == 1
assert next(iterator) == BoundingBox(0, 10, 0, 10, 0, 10)

def test_float_multiple(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 6, 0, 5, 0, 10))
sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
iterator = iter(sampler)
assert len(sampler) == 2
assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10)
assert next(iterator) == BoundingBox(1, 6, 0, 5, 0, 10)

@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
Expand Down
48 changes: 41 additions & 7 deletions torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""TorchGeo samplers."""

import abc
import math
from typing import Callable, Iterable, Iterator, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -146,7 +147,7 @@ def __len__(self) -> int:


class GridGeoSampler(GeoSampler):
"""Samples elements in a grid-like fashion.
r"""Samples elements in a grid-like fashion.

This is particularly useful during evaluation when you want to make predictions for
an entire region of interest. You want to minimize the amount of redundant
Expand All @@ -158,6 +159,21 @@ class GridGeoSampler(GeoSampler):
The overlap between each chip (``chip_size - stride``) should be approximately equal
to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of
the CNN.

Note that the stride of the final set of chips in each row/column may be adjusted so
that the entire :term:`tile` is sampled without exceeding the bounds of the dataset.

Let :math:`i` be the size of the input tile. Let :math:`k` be the requested size of
the output patch. Let :math:`s` be the requested stride. Let :math:`o` be the number
of output rows/columns sampled from each tile. :math:`o` can then be computed as:

.. math::

o = \left\lceil \frac{i - k}{s} \right\rceil + 1

This is almost identical to relationship 5 in
https://doi.org/10.48550/arXiv.1603.07285. However, we use ceiling instead of floor
because we want to include the final remaining chip.
"""

def __init__(
Expand Down Expand Up @@ -200,17 +216,23 @@ def __init__(
for hit in self.index.intersection(tuple(self.roi), objects=True):
bounds = BoundingBox(*hit.bounds)
if (
bounds.maxx - bounds.minx > self.size[1]
and bounds.maxy - bounds.miny > self.size[0]
bounds.maxx - bounds.minx >= self.size[1]
and bounds.maxy - bounds.miny >= self.size[0]
):
self.hits.append(hit)

self.length = 0
for hit in self.hits:
bounds = BoundingBox(*hit.bounds)

rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
rows = (
math.ceil((bounds.maxy - bounds.miny - self.size[0]) / self.stride[0])
+ 1
)
cols = (
math.ceil((bounds.maxx - bounds.minx - self.size[1]) / self.stride[1])
+ 1
)
self.length += rows * cols

def __iter__(self) -> Iterator[BoundingBox]:
Expand All @@ -223,8 +245,14 @@ def __iter__(self) -> Iterator[BoundingBox]:
for hit in self.hits:
bounds = BoundingBox(*hit.bounds)

rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
rows = (
math.ceil((bounds.maxy - bounds.miny - self.size[0]) / self.stride[0])
+ 1
)
cols = (
math.ceil((bounds.maxx - bounds.minx - self.size[1]) / self.stride[1])
+ 1
)

mint = bounds.mint
maxt = bounds.maxt
Expand All @@ -233,11 +261,17 @@ def __iter__(self) -> Iterator[BoundingBox]:
for i in range(rows):
miny = bounds.miny + i * self.stride[0]
maxy = miny + self.size[0]
if maxy > bounds.maxy:
maxy = bounds.maxy
miny = bounds.maxy - self.size[0]

# For each column...
for j in range(cols):
minx = bounds.minx + j * self.stride[1]
maxx = minx + self.size[1]
if maxx > bounds.maxx:
maxx = bounds.maxx
minx = bounds.maxx - self.size[1]

yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)

Expand Down