Skip to content

Commit

Permalink
GridGeoSampler: change stride of last patch to sample entire ROI (mic…
Browse files Browse the repository at this point in the history
…rosoft#630)

* Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning

* style and mypy fixes

* black test fix

* Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning

* style and mypy fixes

* black test fix

* single.py: adapt gridgeosampler to sample beyond limit of ROI for a partial patch (to be padded)
test_single.py: add tests for multiple limit cases (see issue microsoft#448)

* format for black and flake8

* format for black and flake8

* once again, format for black and flake8

* Revert "Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning"

This reverts commit cb554c6

* adapt unit tests, remove warnings

* flake8: remove warnings import

* Address some comments

* Simplify computation of # rows/cols

* Document this new feature

* Fix size of ceiling symbol

* Simplify tests

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
  • Loading branch information
2 people authored and Modexus committed Sep 3, 2022
1 parent cc33d9f commit e388276
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 14 deletions.
35 changes: 29 additions & 6 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ def test_iter(self, sampler: GridGeoSampler) -> None:
)

def test_len(self, sampler: GridGeoSampler) -> None:
rows = ((100 - sampler.size[0]) // sampler.stride[0]) + 1
cols = ((100 - sampler.size[1]) // sampler.stride[1]) + 1
length = rows * cols * 2
rows = math.ceil((100 - sampler.size[0]) / sampler.stride[0]) + 1
cols = math.ceil((100 - sampler.size[1]) / sampler.stride[1]) + 1
length = rows * cols * 2 # two items in dataset
assert len(sampler) == length

def test_roi(self, dataset: CustomGeoDataset) -> None:
Expand All @@ -194,12 +194,35 @@ def test_roi(self, dataset: CustomGeoDataset) -> None:
assert query in roi

def test_small_area(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 1, 0, 1, 0, 1))
sampler = GridGeoSampler(ds, 2, 10)
assert len(sampler) == 0

def test_tiles_side_by_side(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (20, 21, 20, 21, 20, 21))
ds.index.insert(0, (0, 10, 10, 20, 0, 10))
sampler = GridGeoSampler(ds, 2, 10)
for _ in sampler:
continue
for bbox in sampler:
assert bbox.area > 0

def test_integer_multiple(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(ds, 10, 10, units=Units.CRS)
iterator = iter(sampler)
assert len(sampler) == 1
assert next(iterator) == BoundingBox(0, 10, 0, 10, 0, 10)

def test_float_multiple(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 6, 0, 5, 0, 10))
sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
iterator = iter(sampler)
assert len(sampler) == 2
assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10)
assert next(iterator) == BoundingBox(1, 6, 0, 5, 0, 10)

@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
Expand Down
50 changes: 42 additions & 8 deletions torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""TorchGeo samplers."""

import abc
from typing import Any, Callable, Iterable, Iterator, Optional, Sequence, Tuple, Union
import math
from typing import Callable, Iterable, Iterator, Optional, Tuple, Union

import torch
from rtree.index import Index, Property
Expand Down Expand Up @@ -164,7 +165,7 @@ def __len__(self) -> int:


class GridGeoSampler(GeoSampler):
"""Samples elements in a grid-like fashion.
r"""Samples elements in a grid-like fashion.
This is particularly useful during evaluation when you want to make predictions for
an entire region of interest. You want to minimize the amount of redundant
Expand All @@ -176,6 +177,21 @@ class GridGeoSampler(GeoSampler):
The overlap between each chip (``chip_size - stride``) should be approximately equal
to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of
the CNN.
Note that the stride of the final set of chips in each row/column may be adjusted so
that the entire :term:`tile` is sampled without exceeding the bounds of the dataset.
Let :math:`i` be the size of the input tile. Let :math:`k` be the requested size of
the output patch. Let :math:`s` be the requested stride. Let :math:`o` be the number
of output rows/columns sampled from each tile. :math:`o` can then be computed as:
.. math::
o = \left\lceil \frac{i - k}{s} \right\rceil + 1
This is almost identical to relationship 5 in
https://doi.org/10.48550/arXiv.1603.07285. However, we use ceiling instead of floor
because we want to include the final remaining chip.
"""

def __init__(
Expand Down Expand Up @@ -219,17 +235,23 @@ def __init__(
for hit in self.index.intersection(tuple(dataset.bounds), objects=True):
bounds = BoundingBox(*hit.bounds)
if (
bounds.maxx - bounds.minx > self.size[1]
and bounds.maxy - bounds.miny > self.size[0]
bounds.maxx - bounds.minx >= self.size[1]
and bounds.maxy - bounds.miny >= self.size[0]
):
self.hits.append(hit)

self.length = 0
for hit in self.hits:
bounds = BoundingBox(*hit.bounds)

rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
rows = (
math.ceil((bounds.maxy - bounds.miny - self.size[0]) / self.stride[0])
+ 1
)
cols = (
math.ceil((bounds.maxx - bounds.minx - self.size[1]) / self.stride[1])
+ 1
)
self.length += rows * cols

def __iter__(self) -> Iterator[BoundingBox]:
Expand All @@ -242,8 +264,14 @@ def __iter__(self) -> Iterator[BoundingBox]:
for hit in self.hits:
bounds = BoundingBox(*hit.bounds)

rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
rows = (
math.ceil((bounds.maxy - bounds.miny - self.size[0]) / self.stride[0])
+ 1
)
cols = (
math.ceil((bounds.maxx - bounds.minx - self.size[1]) / self.stride[1])
+ 1
)

mint = bounds.mint
maxt = bounds.maxt
Expand All @@ -252,11 +280,17 @@ def __iter__(self) -> Iterator[BoundingBox]:
for i in range(rows):
miny = bounds.miny + i * self.stride[0]
maxy = miny + self.size[0]
if maxy > bounds.maxy:
maxy = bounds.maxy
miny = bounds.maxy - self.size[0]

# For each column...
for j in range(cols):
minx = bounds.minx + j * self.stride[1]
maxx = minx + self.size[1]
if maxx > bounds.maxx:
maxx = bounds.maxx
minx = bounds.maxx - self.size[1]

yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)

Expand Down

0 comments on commit e388276

Please sign in to comment.