From ad093fbc5dbec3843f20249c7640ec5e8eb78bde Mon Sep 17 00:00:00 2001 From: remtav Date: Wed, 2 Mar 2022 11:28:21 -0500 Subject: [PATCH 01/18] Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning --- torchgeo/samplers/single.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 781cb5c38b1..b90ac50d72e 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -5,6 +5,7 @@ import abc import random +import warnings from typing import Iterator, Optional, Tuple, Union from rtree.index import Index, Property @@ -180,7 +181,7 @@ def __init__( .. versionchanged:: 0.3 Added ``units`` parameter, changed default to pixel units """ - super().__init__(dataset, roi) + super().__init__(dataset=dataset, roi=roi, stride=stride, size=size) self.size = _to_tuple(size) self.stride = _to_tuple(stride) @@ -201,8 +202,8 @@ def __init__( for hit in self.hits: bounds = BoundingBox(*hit.bounds) - rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1 - cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1 + rows = int((bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) // self.stride[0]) + 1 + cols = int((bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) // self.stride[1]) + 1 self.length += rows * cols def __iter__(self) -> Iterator[BoundingBox]: @@ -215,8 +216,8 @@ def __iter__(self) -> Iterator[BoundingBox]: for hit in self.hits: bounds = BoundingBox(*hit.bounds) - rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1 - cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1 + rows = int((bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) // self.stride[0]) + 1 + cols = int((bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) // self.stride[1]) + 1 mint = bounds.mint maxt = bounds.maxt @@ -225,11 +226,29 @@ def __iter__(self) -> Iterator[BoundingBox]: for i in range(rows): miny = bounds.miny + i * self.stride[0] maxy = miny + self.size[0] + if maxy > bounds.maxy: + last_stride_y = self.stride[0] - (miny - (bounds.maxy - self.size[0])) + maxy = bounds.maxy + miny = bounds.maxy - self.size[0] + warnings.warn( + f"Max y coordinate of bounding box reaches passed y bounds of source tile" + f"Bounding box will be moved to set max y at source tile's max y. Stride will be adjusted" + f"from {self.stride[0]:.2f} to {last_stride_y:.2f}" + ) # For each column... for j in range(cols): minx = bounds.minx + j * self.stride[1] maxx = minx + self.size[1] + if maxx > bounds.maxx: + last_stride_x = self.stride[1] - (minx - (bounds.maxx - self.size[1])) + maxx = bounds.maxx + minx = bounds.maxx - self.size[1] + warnings.warn( + f"Max x coordinate of bounding box reaches passed x bounds of source tile" + f"Bounding box will be moved to set max x at source tile's max x. Stride will be adjusted" + f"from {self.stride[1]:.2f} to {last_stride_x:.2f}" + ) yield BoundingBox(minx, maxx, miny, maxy, mint, maxt) From 855fee1b7456a8e9734a495a05407e1a8644adb5 Mon Sep 17 00:00:00 2001 From: remtav Date: Mon, 14 Mar 2022 15:41:52 -0400 Subject: [PATCH 02/18] style and mypy fixes --- torchgeo/samplers/single.py | 54 +++++++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index b90ac50d72e..3fb1d82713e 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -181,7 +181,7 @@ def __init__( .. versionchanged:: 0.3 Added ``units`` parameter, changed default to pixel units """ - super().__init__(dataset=dataset, roi=roi, stride=stride, size=size) + super().__init__(dataset=dataset, roi=roi) self.size = _to_tuple(size) self.stride = _to_tuple(stride) @@ -202,8 +202,20 @@ def __init__( for hit in self.hits: bounds = BoundingBox(*hit.bounds) - rows = int((bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) // self.stride[0]) + 1 - cols = int((bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) // self.stride[1]) + 1 + rows = ( + int( + (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) + // self.stride[0] + ) + + 1 + ) + cols = ( + int( + (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) + // self.stride[1] + ) + + 1 + ) self.length += rows * cols def __iter__(self) -> Iterator[BoundingBox]: @@ -216,8 +228,20 @@ def __iter__(self) -> Iterator[BoundingBox]: for hit in self.hits: bounds = BoundingBox(*hit.bounds) - rows = int((bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) // self.stride[0]) + 1 - cols = int((bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) // self.stride[1]) + 1 + rows = ( + int( + (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) + // self.stride[0] + ) + + 1 + ) + cols = ( + int( + (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) + // self.stride[1] + ) + + 1 + ) mint = bounds.mint maxt = bounds.maxt @@ -227,12 +251,16 @@ def __iter__(self) -> Iterator[BoundingBox]: miny = bounds.miny + i * self.stride[0] maxy = miny + self.size[0] if maxy > bounds.maxy: - last_stride_y = self.stride[0] - (miny - (bounds.maxy - self.size[0])) + last_stride_y = self.stride[0] - ( + miny - (bounds.maxy - self.size[0]) + ) maxy = bounds.maxy miny = bounds.maxy - self.size[0] warnings.warn( - f"Max y coordinate of bounding box reaches passed y bounds of source tile" - f"Bounding box will be moved to set max y at source tile's max y. Stride will be adjusted" + f"Max y coordinate of bounding box reaches passed y bounds of " + f"source tile" + f"Bounding box will be moved to set max y at source tile's max" + f" y. Stride will be adjusted" f"from {self.stride[0]:.2f} to {last_stride_y:.2f}" ) @@ -241,12 +269,16 @@ def __iter__(self) -> Iterator[BoundingBox]: minx = bounds.minx + j * self.stride[1] maxx = minx + self.size[1] if maxx > bounds.maxx: - last_stride_x = self.stride[1] - (minx - (bounds.maxx - self.size[1])) + last_stride_x = self.stride[1] - ( + minx - (bounds.maxx - self.size[1]) + ) maxx = bounds.maxx minx = bounds.maxx - self.size[1] warnings.warn( - f"Max x coordinate of bounding box reaches passed x bounds of source tile" - f"Bounding box will be moved to set max x at source tile's max x. Stride will be adjusted" + f"Max x coordinate of bounding box reaches passed x bounds" + f" of source tile" + f"Bounding box will be moved to set max x at source tile's" + f" max x. Stride will be adjusted" f"from {self.stride[1]:.2f} to {last_stride_x:.2f}" ) From eb33fe001e478216ef114a33ecc9d4e813bed9a3 Mon Sep 17 00:00:00 2001 From: remtav Date: Mon, 14 Mar 2022 16:17:44 -0400 Subject: [PATCH 03/18] black test fix --- torchgeo/samplers/single.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 3fb1d82713e..eed8b820286 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -210,11 +210,11 @@ def __init__( + 1 ) cols = ( - int( - (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) - // self.stride[1] - ) - + 1 + int( + (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) + // self.stride[1] + ) + + 1 ) self.length += rows * cols @@ -229,18 +229,18 @@ def __iter__(self) -> Iterator[BoundingBox]: bounds = BoundingBox(*hit.bounds) rows = ( - int( - (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) - // self.stride[0] - ) - + 1 + int( + (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) + // self.stride[0] + ) + + 1 ) cols = ( - int( - (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) - // self.stride[1] - ) - + 1 + int( + (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) + // self.stride[1] + ) + + 1 ) mint = bounds.mint @@ -252,7 +252,7 @@ def __iter__(self) -> Iterator[BoundingBox]: maxy = miny + self.size[0] if maxy > bounds.maxy: last_stride_y = self.stride[0] - ( - miny - (bounds.maxy - self.size[0]) + miny - (bounds.maxy - self.size[0]) ) maxy = bounds.maxy miny = bounds.maxy - self.size[0] @@ -270,7 +270,7 @@ def __iter__(self) -> Iterator[BoundingBox]: maxx = minx + self.size[1] if maxx > bounds.maxx: last_stride_x = self.stride[1] - ( - minx - (bounds.maxx - self.size[1]) + minx - (bounds.maxx - self.size[1]) ) maxx = bounds.maxx minx = bounds.maxx - self.size[1] From cb554c67386ae252a26028cbe7903b94f58b1bd6 Mon Sep 17 00:00:00 2001 From: remtav Date: Wed, 2 Mar 2022 11:28:21 -0500 Subject: [PATCH 04/18] Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning --- torchgeo/samplers/single.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index e31d13bdd44..14dd6b49e23 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -4,6 +4,8 @@ """TorchGeo samplers.""" import abc +import random +import warnings from typing import Callable, Iterable, Iterator, Optional, Tuple, Union import torch @@ -188,7 +190,7 @@ def __init__( .. versionchanged:: 0.3 Added ``units`` parameter, changed default to pixel units """ - super().__init__(dataset, roi) + super().__init__(dataset=dataset, roi=roi, stride=stride, size=size) self.size = _to_tuple(size) self.stride = _to_tuple(stride) @@ -209,8 +211,8 @@ def __init__( for hit in self.hits: bounds = BoundingBox(*hit.bounds) - rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1 - cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1 + rows = int((bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) // self.stride[0]) + 1 + cols = int((bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) // self.stride[1]) + 1 self.length += rows * cols def __iter__(self) -> Iterator[BoundingBox]: @@ -223,8 +225,8 @@ def __iter__(self) -> Iterator[BoundingBox]: for hit in self.hits: bounds = BoundingBox(*hit.bounds) - rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1 - cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1 + rows = int((bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) // self.stride[0]) + 1 + cols = int((bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) // self.stride[1]) + 1 mint = bounds.mint maxt = bounds.maxt @@ -233,11 +235,29 @@ def __iter__(self) -> Iterator[BoundingBox]: for i in range(rows): miny = bounds.miny + i * self.stride[0] maxy = miny + self.size[0] + if maxy > bounds.maxy: + last_stride_y = self.stride[0] - (miny - (bounds.maxy - self.size[0])) + maxy = bounds.maxy + miny = bounds.maxy - self.size[0] + warnings.warn( + f"Max y coordinate of bounding box reaches passed y bounds of source tile" + f"Bounding box will be moved to set max y at source tile's max y. Stride will be adjusted" + f"from {self.stride[0]:.2f} to {last_stride_y:.2f}" + ) # For each column... for j in range(cols): minx = bounds.minx + j * self.stride[1] maxx = minx + self.size[1] + if maxx > bounds.maxx: + last_stride_x = self.stride[1] - (minx - (bounds.maxx - self.size[1])) + maxx = bounds.maxx + minx = bounds.maxx - self.size[1] + warnings.warn( + f"Max x coordinate of bounding box reaches passed x bounds of source tile" + f"Bounding box will be moved to set max x at source tile's max x. Stride will be adjusted" + f"from {self.stride[1]:.2f} to {last_stride_x:.2f}" + ) yield BoundingBox(minx, maxx, miny, maxy, mint, maxt) From 1a0236c83521aa5864cc0668b66cff263677c3dd Mon Sep 17 00:00:00 2001 From: remtav Date: Mon, 14 Mar 2022 15:41:52 -0400 Subject: [PATCH 05/18] style and mypy fixes --- torchgeo/samplers/single.py | 54 +++++++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 14dd6b49e23..63674b8cc85 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -190,7 +190,7 @@ def __init__( .. versionchanged:: 0.3 Added ``units`` parameter, changed default to pixel units """ - super().__init__(dataset=dataset, roi=roi, stride=stride, size=size) + super().__init__(dataset=dataset, roi=roi) self.size = _to_tuple(size) self.stride = _to_tuple(stride) @@ -211,8 +211,20 @@ def __init__( for hit in self.hits: bounds = BoundingBox(*hit.bounds) - rows = int((bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) // self.stride[0]) + 1 - cols = int((bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) // self.stride[1]) + 1 + rows = ( + int( + (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) + // self.stride[0] + ) + + 1 + ) + cols = ( + int( + (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) + // self.stride[1] + ) + + 1 + ) self.length += rows * cols def __iter__(self) -> Iterator[BoundingBox]: @@ -225,8 +237,20 @@ def __iter__(self) -> Iterator[BoundingBox]: for hit in self.hits: bounds = BoundingBox(*hit.bounds) - rows = int((bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) // self.stride[0]) + 1 - cols = int((bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) // self.stride[1]) + 1 + rows = ( + int( + (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) + // self.stride[0] + ) + + 1 + ) + cols = ( + int( + (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) + // self.stride[1] + ) + + 1 + ) mint = bounds.mint maxt = bounds.maxt @@ -236,12 +260,16 @@ def __iter__(self) -> Iterator[BoundingBox]: miny = bounds.miny + i * self.stride[0] maxy = miny + self.size[0] if maxy > bounds.maxy: - last_stride_y = self.stride[0] - (miny - (bounds.maxy - self.size[0])) + last_stride_y = self.stride[0] - ( + miny - (bounds.maxy - self.size[0]) + ) maxy = bounds.maxy miny = bounds.maxy - self.size[0] warnings.warn( - f"Max y coordinate of bounding box reaches passed y bounds of source tile" - f"Bounding box will be moved to set max y at source tile's max y. Stride will be adjusted" + f"Max y coordinate of bounding box reaches passed y bounds of " + f"source tile" + f"Bounding box will be moved to set max y at source tile's max" + f" y. Stride will be adjusted" f"from {self.stride[0]:.2f} to {last_stride_y:.2f}" ) @@ -250,12 +278,16 @@ def __iter__(self) -> Iterator[BoundingBox]: minx = bounds.minx + j * self.stride[1] maxx = minx + self.size[1] if maxx > bounds.maxx: - last_stride_x = self.stride[1] - (minx - (bounds.maxx - self.size[1])) + last_stride_x = self.stride[1] - ( + minx - (bounds.maxx - self.size[1]) + ) maxx = bounds.maxx minx = bounds.maxx - self.size[1] warnings.warn( - f"Max x coordinate of bounding box reaches passed x bounds of source tile" - f"Bounding box will be moved to set max x at source tile's max x. Stride will be adjusted" + f"Max x coordinate of bounding box reaches passed x bounds" + f" of source tile" + f"Bounding box will be moved to set max x at source tile's" + f" max x. Stride will be adjusted" f"from {self.stride[1]:.2f} to {last_stride_x:.2f}" ) From bfadf762018fed8c6b4ce66e48ca7a78cd201f8f Mon Sep 17 00:00:00 2001 From: remtav Date: Mon, 14 Mar 2022 16:17:44 -0400 Subject: [PATCH 06/18] black test fix --- torchgeo/samplers/single.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 63674b8cc85..a18f06bb97b 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -219,11 +219,11 @@ def __init__( + 1 ) cols = ( - int( - (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) - // self.stride[1] - ) - + 1 + int( + (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) + // self.stride[1] + ) + + 1 ) self.length += rows * cols @@ -238,18 +238,18 @@ def __iter__(self) -> Iterator[BoundingBox]: bounds = BoundingBox(*hit.bounds) rows = ( - int( - (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) - // self.stride[0] - ) - + 1 + int( + (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) + // self.stride[0] + ) + + 1 ) cols = ( - int( - (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) - // self.stride[1] - ) - + 1 + int( + (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) + // self.stride[1] + ) + + 1 ) mint = bounds.mint @@ -261,7 +261,7 @@ def __iter__(self) -> Iterator[BoundingBox]: maxy = miny + self.size[0] if maxy > bounds.maxy: last_stride_y = self.stride[0] - ( - miny - (bounds.maxy - self.size[0]) + miny - (bounds.maxy - self.size[0]) ) maxy = bounds.maxy miny = bounds.maxy - self.size[0] @@ -279,7 +279,7 @@ def __iter__(self) -> Iterator[BoundingBox]: maxx = minx + self.size[1] if maxx > bounds.maxx: last_stride_x = self.stride[1] - ( - minx - (bounds.maxx - self.size[1]) + minx - (bounds.maxx - self.size[1]) ) maxx = bounds.maxx minx = bounds.maxx - self.size[1] From 6942b5fbb0e3a866ba80255f828f8f3217485579 Mon Sep 17 00:00:00 2001 From: remtav Date: Tue, 28 Jun 2022 14:01:15 -0400 Subject: [PATCH 07/18] single.py: adapt gridgeosampler to sample beyond limit of ROI for a partial patch (to be padded) test_single.py: add tests for multiple limit cases (see issue #448) --- tests/samplers/test_single.py | 55 ++++++++++++++++++++---- torchgeo/samplers/single.py | 79 ++++++----------------------------- 2 files changed, 60 insertions(+), 74 deletions(-) diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 2380bb119fd..ef7f88440c4 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -171,9 +171,13 @@ def sampler(self, dataset: CustomGeoDataset, request: SubRequest) -> GridGeoSamp def test_iter(self, sampler: GridGeoSampler) -> None: for query in sampler: - assert sampler.roi.minx <= query.minx <= query.maxx <= sampler.roi.maxx - assert sampler.roi.miny <= query.miny <= query.miny <= sampler.roi.maxy - assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt + assert sampler.roi.minx <= query.minx + assert sampler.roi.miny <= query.miny + assert sampler.roi.mint <= query.mint + if query.maxx > sampler.roi.maxx: + assert (query.maxx - sampler.roi.maxx) < sampler.size[1] + if query.maxy > sampler.roi.maxy: + assert (query.maxy - sampler.roi.maxy) < sampler.size[0] assert math.isclose(query.maxx - query.minx, sampler.size[1]) assert math.isclose(query.maxy - query.miny, sampler.size[0]) @@ -182,11 +186,21 @@ def test_iter(self, sampler: GridGeoSampler) -> None: ) def test_len(self, sampler: GridGeoSampler) -> None: - rows = int((100 - sampler.size[0]) // sampler.stride[0]) + 1 - cols = int((100 - sampler.size[1]) // sampler.stride[1]) + 1 + rows = math.ceil((100 - sampler.size[0] + sampler.stride[0]) / sampler.stride[0]) + cols = math.ceil((100 - sampler.size[1] + sampler.stride[1]) / sampler.stride[1]) length = rows * cols * 2 assert len(sampler) == length + def test_len_larger(self, sampler: GridGeoSampler) -> None: + entire_rows = (100 - sampler.size[0] + sampler.stride[0]) // sampler.stride[0] + entire_cols = (100 - sampler.size[1] + sampler.stride[1]) // sampler.stride[1] + leftover_row = (100 - sampler.size[0] + sampler.stride[0]) \ + / sampler.stride[0] - entire_rows + leftover_col = (100 - sampler.size[1] + sampler.stride[1]) \ + / sampler.stride[1] - entire_cols + assert len(sampler) == (entire_rows + math.ceil(leftover_row)) * \ + (entire_cols + math.ceil(leftover_col)) * 2 + def test_roi(self, dataset: CustomGeoDataset) -> None: roi = BoundingBox(0, 50, 200, 250, 400, 450) sampler = GridGeoSampler(dataset, 2, 1, roi=roi) @@ -194,12 +208,37 @@ def test_roi(self, dataset: CustomGeoDataset) -> None: assert query in roi def test_small_area(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 1, 0, 1, 0, 1)) + sampler = GridGeoSampler(ds, 2, 10) + assert len(sampler) == 1 + for bbox in sampler: + assert bbox == BoundingBox(minx=0.0, maxx=20.0, miny=0.0, maxy=20.0, mint=0.0, maxt=1.0) + + # TODO: skip patches with area=0 when two tiles are side-by-side with an overlapping edge face. + def test_tiles_side_by_side(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - ds.index.insert(1, (20, 21, 20, 21, 20, 21)) + ds.index.insert(0, (0, 10, 10, 20, 0, 10)) sampler = GridGeoSampler(ds, 2, 10) - for _ in sampler: - continue + for bbox in sampler: + assert bbox.area > 0 + + def test_equal_area(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler = GridGeoSampler(ds, 10, 10, units=Units.CRS) + assert len(sampler) == 1 + for bbox in sampler: + assert bbox == BoundingBox(minx=0.0, maxx=10.0, miny=0.0, maxy=10.0, mint=0.0, maxt=10.0) + + def test_larger_area(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 6, 0, 5, 0, 10)) + sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS) + assert len(sampler) == 2 + assert list(sampler)[0] == BoundingBox(minx=0.0, maxx=5.0, miny=0.0, maxy=5.0, mint=0.0, maxt=10.0) + assert list(sampler)[1] == BoundingBox(minx=5.0, maxx=10.0, miny=0.0, maxy=5.0, mint=0.0, maxt=10.0) @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index a18f06bb97b..6375db9968c 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -4,8 +4,7 @@ """TorchGeo samplers.""" import abc -import random -import warnings +import math from typing import Callable, Iterable, Iterator, Optional, Tuple, Union import torch @@ -109,12 +108,8 @@ def __init__( areas = [] for hit in self.index.intersection(tuple(self.roi), objects=True): bounds = BoundingBox(*hit.bounds) - if ( - bounds.maxx - bounds.minx >= self.size[1] - and bounds.maxy - bounds.miny >= self.size[0] - ): - self.hits.append(hit) - areas.append(bounds.area) + self.hits.append(hit) + areas.append(bounds.area) # torch.multinomial requires float probabilities > 0 self.areas = torch.tensor(areas, dtype=torch.float) @@ -198,32 +193,18 @@ def __init__( self.size = (self.size[0] * self.res, self.size[1] * self.res) self.stride = (self.stride[0] * self.res, self.stride[1] * self.res) - self.hits = [] - for hit in self.index.intersection(tuple(self.roi), objects=True): - bounds = BoundingBox(*hit.bounds) - if ( - bounds.maxx - bounds.minx > self.size[1] - and bounds.maxy - bounds.miny > self.size[0] - ): - self.hits.append(hit) + self.hits = list(self.index.intersection(tuple(self.roi), objects=True)) self.length: int = 0 for hit in self.hits: bounds = BoundingBox(*hit.bounds) - rows = ( - int( - (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) - // self.stride[0] - ) - + 1 + # last patch samples outside the bounds + rows = math.ceil( + (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) / self.stride[0] ) - cols = ( - int( - (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) - // self.stride[1] - ) - + 1 + cols = math.ceil( + (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) / self.stride[1] ) self.length += rows * cols @@ -237,19 +218,11 @@ def __iter__(self) -> Iterator[BoundingBox]: for hit in self.hits: bounds = BoundingBox(*hit.bounds) - rows = ( - int( - (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) - // self.stride[0] - ) - + 1 + rows = math.ceil( + (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) / self.stride[0] ) - cols = ( - int( - (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) - // self.stride[1] - ) - + 1 + cols = math.ceil( + (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) / self.stride[1] ) mint = bounds.mint @@ -259,37 +232,11 @@ def __iter__(self) -> Iterator[BoundingBox]: for i in range(rows): miny = bounds.miny + i * self.stride[0] maxy = miny + self.size[0] - if maxy > bounds.maxy: - last_stride_y = self.stride[0] - ( - miny - (bounds.maxy - self.size[0]) - ) - maxy = bounds.maxy - miny = bounds.maxy - self.size[0] - warnings.warn( - f"Max y coordinate of bounding box reaches passed y bounds of " - f"source tile" - f"Bounding box will be moved to set max y at source tile's max" - f" y. Stride will be adjusted" - f"from {self.stride[0]:.2f} to {last_stride_y:.2f}" - ) # For each column... for j in range(cols): minx = bounds.minx + j * self.stride[1] maxx = minx + self.size[1] - if maxx > bounds.maxx: - last_stride_x = self.stride[1] - ( - minx - (bounds.maxx - self.size[1]) - ) - maxx = bounds.maxx - minx = bounds.maxx - self.size[1] - warnings.warn( - f"Max x coordinate of bounding box reaches passed x bounds" - f" of source tile" - f"Bounding box will be moved to set max x at source tile's" - f" max x. Stride will be adjusted" - f"from {self.stride[1]:.2f} to {last_stride_x:.2f}" - ) yield BoundingBox(minx, maxx, miny, maxy, mint, maxt) From 138855b5961267ec64799d2eb7173d416ea3afa7 Mon Sep 17 00:00:00 2001 From: remtav Date: Tue, 28 Jun 2022 14:22:52 -0400 Subject: [PATCH 08/18] format for black and flake8 --- tests/samplers/test_single.py | 45 +++++++++++++++++++++++++---------- torchgeo/samplers/single.py | 6 +++-- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index ef7f88440c4..2e37370cb02 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -186,20 +186,30 @@ def test_iter(self, sampler: GridGeoSampler) -> None: ) def test_len(self, sampler: GridGeoSampler) -> None: - rows = math.ceil((100 - sampler.size[0] + sampler.stride[0]) / sampler.stride[0]) - cols = math.ceil((100 - sampler.size[1] + sampler.stride[1]) / sampler.stride[1]) + rows = math.ceil( + (100 - sampler.size[0] + sampler.stride[0]) / sampler.stride[0] + ) + cols = math.ceil( + (100 - sampler.size[1] + sampler.stride[1]) / sampler.stride[1] + ) length = rows * cols * 2 assert len(sampler) == length def test_len_larger(self, sampler: GridGeoSampler) -> None: entire_rows = (100 - sampler.size[0] + sampler.stride[0]) // sampler.stride[0] entire_cols = (100 - sampler.size[1] + sampler.stride[1]) // sampler.stride[1] - leftover_row = (100 - sampler.size[0] + sampler.stride[0]) \ - / sampler.stride[0] - entire_rows - leftover_col = (100 - sampler.size[1] + sampler.stride[1]) \ - / sampler.stride[1] - entire_cols - assert len(sampler) == (entire_rows + math.ceil(leftover_row)) * \ - (entire_cols + math.ceil(leftover_col)) * 2 + leftover_row = (100 - sampler.size[0] + sampler.stride[0]) / sampler.stride[ + 0 + ] - entire_rows + leftover_col = (100 - sampler.size[1] + sampler.stride[1]) / sampler.stride[ + 1 + ] - entire_cols + assert ( + len(sampler) + == (entire_rows + math.ceil(leftover_row)) + * (entire_cols + math.ceil(leftover_col)) + * 2 + ) def test_roi(self, dataset: CustomGeoDataset) -> None: roi = BoundingBox(0, 50, 200, 250, 400, 450) @@ -213,9 +223,12 @@ def test_small_area(self) -> None: sampler = GridGeoSampler(ds, 2, 10) assert len(sampler) == 1 for bbox in sampler: - assert bbox == BoundingBox(minx=0.0, maxx=20.0, miny=0.0, maxy=20.0, mint=0.0, maxt=1.0) + assert bbox == BoundingBox( + minx=0.0, maxx=20.0, miny=0.0, maxy=20.0, mint=0.0, maxt=1.0 + ) - # TODO: skip patches with area=0 when two tiles are side-by-side with an overlapping edge face. + # TODO: skip patches with area=0 when two tiles are + # side-by-side with an overlapping edge face. def test_tiles_side_by_side(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) @@ -230,15 +243,21 @@ def test_equal_area(self) -> None: sampler = GridGeoSampler(ds, 10, 10, units=Units.CRS) assert len(sampler) == 1 for bbox in sampler: - assert bbox == BoundingBox(minx=0.0, maxx=10.0, miny=0.0, maxy=10.0, mint=0.0, maxt=10.0) + assert bbox == BoundingBox( + minx=0.0, maxx=10.0, miny=0.0, maxy=10.0, mint=0.0, maxt=10.0 + ) def test_larger_area(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 6, 0, 5, 0, 10)) sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS) assert len(sampler) == 2 - assert list(sampler)[0] == BoundingBox(minx=0.0, maxx=5.0, miny=0.0, maxy=5.0, mint=0.0, maxt=10.0) - assert list(sampler)[1] == BoundingBox(minx=5.0, maxx=10.0, miny=0.0, maxy=5.0, mint=0.0, maxt=10.0) + assert list(sampler)[0] == BoundingBox( + minx=0.0, maxx=5.0, miny=0.0, maxy=5.0, mint=0.0, maxt=10.0 + ) + assert list(sampler)[1] == BoundingBox( + minx=5.0, maxx=10.0, miny=0.0, maxy=5.0, mint=0.0, maxt=10.0 + ) @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 6375db9968c..f5ac57f76ef 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -201,10 +201,12 @@ def __init__( # last patch samples outside the bounds rows = math.ceil( - (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) / self.stride[0] + (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) + / self.stride[0] ) cols = math.ceil( - (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) / self.stride[1] + (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) + / self.stride[1] ) self.length += rows * cols From 37d6055c8cbc2c2189b951e6044d4fe866c57c59 Mon Sep 17 00:00:00 2001 From: remtav Date: Tue, 28 Jun 2022 14:39:29 -0400 Subject: [PATCH 09/18] format for black and flake8 --- tests/samplers/test_single.py | 8 ++++---- torchgeo/samplers/single.py | 6 ++++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 2e37370cb02..04c4b83e66c 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -205,10 +205,10 @@ def test_len_larger(self, sampler: GridGeoSampler) -> None: 1 ] - entire_cols assert ( - len(sampler) - == (entire_rows + math.ceil(leftover_row)) - * (entire_cols + math.ceil(leftover_col)) - * 2 + len(sampler) + == (entire_rows + math.ceil(leftover_row)) + * (entire_cols + math.ceil(leftover_col)) + * 2 ) def test_roi(self, dataset: CustomGeoDataset) -> None: diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index f5ac57f76ef..54f63246699 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -221,10 +221,12 @@ def __iter__(self) -> Iterator[BoundingBox]: bounds = BoundingBox(*hit.bounds) rows = math.ceil( - (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) / self.stride[0] + (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) + / self.stride[0] ) cols = math.ceil( - (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) / self.stride[1] + (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) + / self.stride[1] ) mint = bounds.mint From b0ab3fa72841d4e1e067180906771ea39c480dc3 Mon Sep 17 00:00:00 2001 From: remtav Date: Tue, 28 Jun 2022 14:53:03 -0400 Subject: [PATCH 10/18] once again, format for black and flake8 --- tests/samplers/test_single.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 04c4b83e66c..aba6baa24a6 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -226,7 +226,7 @@ def test_small_area(self) -> None: assert bbox == BoundingBox( minx=0.0, maxx=20.0, miny=0.0, maxy=20.0, mint=0.0, maxt=1.0 ) - + # TODO: skip patches with area=0 when two tiles are # side-by-side with an overlapping edge face. def test_tiles_side_by_side(self) -> None: From af5a3d1cb80b8e821006e0349e18589cdc9c17e1 Mon Sep 17 00:00:00 2001 From: remtav Date: Mon, 29 Aug 2022 12:11:25 -0400 Subject: [PATCH 11/18] Revert "Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning" This reverts commit cb554c67 --- torchgeo/samplers/single.py | 40 ++++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 95dee251390..cf2c2ba04f3 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -5,6 +5,7 @@ import abc import math +import warnings from typing import Callable, Iterable, Iterator, Optional, Tuple, Union import torch @@ -108,8 +109,12 @@ def __init__( areas = [] for hit in self.index.intersection(tuple(self.roi), objects=True): bounds = BoundingBox(*hit.bounds) - self.hits.append(hit) - areas.append(bounds.area) + if ( + bounds.maxx - bounds.minx >= self.size[1] + and bounds.maxy - bounds.miny >= self.size[0] + ): + self.hits.append(hit) + areas.append(bounds.area) # torch.multinomial requires float probabilities > 0 self.areas = torch.tensor(areas, dtype=torch.float) @@ -185,7 +190,7 @@ def __init__( .. versionchanged:: 0.3 Added ``units`` parameter, changed default to pixel units """ - super().__init__(dataset=dataset, roi=roi) + super().__init__(dataset, roi) self.size = _to_tuple(size) self.stride = _to_tuple(stride) @@ -193,7 +198,14 @@ def __init__( self.size = (self.size[0] * self.res, self.size[1] * self.res) self.stride = (self.stride[0] * self.res, self.stride[1] * self.res) - self.hits = list(self.index.intersection(tuple(self.roi), objects=True)) + self.hits = [] + for hit in self.index.intersection(tuple(self.roi), objects=True): + bounds = BoundingBox(*hit.bounds) + if ( + bounds.maxx - bounds.minx > self.size[1] + and bounds.maxy - bounds.miny > self.size[0] + ): + self.hits.append(hit) self.length: int = 0 for hit in self.hits: @@ -236,11 +248,29 @@ def __iter__(self) -> Iterator[BoundingBox]: for i in range(rows): miny = bounds.miny + i * self.stride[0] maxy = miny + self.size[0] + if maxy > bounds.maxy: + last_stride_y = self.stride[0] - (miny - (bounds.maxy - self.size[0])) + maxy = bounds.maxy + miny = bounds.maxy - self.size[0] + warnings.warn( + f"Max y coordinate of bounding box reaches passed y bounds of source tile" + f"Bounding box will be moved to set max y at source tile's max y. Stride will be adjusted" + f"from {self.stride[0]:.2f} to {last_stride_y:.2f}" + ) # For each column... for j in range(cols): minx = bounds.minx + j * self.stride[1] maxx = minx + self.size[1] + if maxx > bounds.maxx: + last_stride_x = self.stride[1] - (minx - (bounds.maxx - self.size[1])) + maxx = bounds.maxx + minx = bounds.maxx - self.size[1] + warnings.warn( + f"Max x coordinate of bounding box reaches passed x bounds of source tile" + f"Bounding box will be moved to set max x at source tile's max x. Stride will be adjusted" + f"from {self.stride[1]:.2f} to {last_stride_x:.2f}" + ) yield BoundingBox(minx, maxx, miny, maxy, mint, maxt) @@ -260,7 +290,7 @@ class PreChippedGeoSampler(GeoSampler): and subclass :class:`~torchgeo.datasets.GeoDataset` but have already been pre-processed into :term:`chips `. - This sampler should not be used with :class:`~torchgeo.datasets.NonGeoDataset`. + This sampler should not be used with :class:`~torchgeo.datasets.VisionDataset`. You may encounter problems when using an :term:`ROI ` that partially intersects with one of the file bounding boxes, when using an :class:`~torchgeo.datasets.IntersectionDataset`, or when each file is in a From 0e61b1d11d95a42454ebf7b5403a51462ba6d868 Mon Sep 17 00:00:00 2001 From: remtav Date: Mon, 29 Aug 2022 16:40:17 -0400 Subject: [PATCH 12/18] adapt unit tests, remove warnings --- tests/samplers/test_single.py | 10 +++------- torchgeo/samplers/single.py | 18 +++--------------- 2 files changed, 6 insertions(+), 22 deletions(-) diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index aba6baa24a6..9019ad996c4 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -192,7 +192,7 @@ def test_len(self, sampler: GridGeoSampler) -> None: cols = math.ceil( (100 - sampler.size[1] + sampler.stride[1]) / sampler.stride[1] ) - length = rows * cols * 2 + length = rows * cols * 2 # two items in dataset assert len(sampler) == length def test_len_larger(self, sampler: GridGeoSampler) -> None: @@ -221,11 +221,7 @@ def test_small_area(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 1, 0, 1, 0, 1)) sampler = GridGeoSampler(ds, 2, 10) - assert len(sampler) == 1 - for bbox in sampler: - assert bbox == BoundingBox( - minx=0.0, maxx=20.0, miny=0.0, maxy=20.0, mint=0.0, maxt=1.0 - ) + assert len(sampler) == 0 # TODO: skip patches with area=0 when two tiles are # side-by-side with an overlapping edge face. @@ -256,7 +252,7 @@ def test_larger_area(self) -> None: minx=0.0, maxx=5.0, miny=0.0, maxy=5.0, mint=0.0, maxt=10.0 ) assert list(sampler)[1] == BoundingBox( - minx=5.0, maxx=10.0, miny=0.0, maxy=5.0, mint=0.0, maxt=10.0 + minx=1.0, maxx=6.0, miny=0.0, maxy=5.0, mint=0.0, maxt=10.0 ) @pytest.mark.slow diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index cf2c2ba04f3..f30034c73a8 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -202,8 +202,8 @@ def __init__( for hit in self.index.intersection(tuple(self.roi), objects=True): bounds = BoundingBox(*hit.bounds) if ( - bounds.maxx - bounds.minx > self.size[1] - and bounds.maxy - bounds.miny > self.size[0] + bounds.maxx - bounds.minx >= self.size[1] + and bounds.maxy - bounds.miny >= self.size[0] ): self.hits.append(hit) @@ -249,28 +249,16 @@ def __iter__(self) -> Iterator[BoundingBox]: miny = bounds.miny + i * self.stride[0] maxy = miny + self.size[0] if maxy > bounds.maxy: - last_stride_y = self.stride[0] - (miny - (bounds.maxy - self.size[0])) maxy = bounds.maxy miny = bounds.maxy - self.size[0] - warnings.warn( - f"Max y coordinate of bounding box reaches passed y bounds of source tile" - f"Bounding box will be moved to set max y at source tile's max y. Stride will be adjusted" - f"from {self.stride[0]:.2f} to {last_stride_y:.2f}" - ) # For each column... for j in range(cols): minx = bounds.minx + j * self.stride[1] maxx = minx + self.size[1] if maxx > bounds.maxx: - last_stride_x = self.stride[1] - (minx - (bounds.maxx - self.size[1])) maxx = bounds.maxx minx = bounds.maxx - self.size[1] - warnings.warn( - f"Max x coordinate of bounding box reaches passed x bounds of source tile" - f"Bounding box will be moved to set max x at source tile's max x. Stride will be adjusted" - f"from {self.stride[1]:.2f} to {last_stride_x:.2f}" - ) yield BoundingBox(minx, maxx, miny, maxy, mint, maxt) @@ -290,7 +278,7 @@ class PreChippedGeoSampler(GeoSampler): and subclass :class:`~torchgeo.datasets.GeoDataset` but have already been pre-processed into :term:`chips `. - This sampler should not be used with :class:`~torchgeo.datasets.VisionDataset`. + This sampler should not be used with :class:`~torchgeo.datasets.NonGeoDataset`. You may encounter problems when using an :term:`ROI ` that partially intersects with one of the file bounding boxes, when using an :class:`~torchgeo.datasets.IntersectionDataset`, or when each file is in a From 6c623d83801926dc96d548580b2c581d3314d3b5 Mon Sep 17 00:00:00 2001 From: remtav Date: Tue, 30 Aug 2022 09:45:06 -0400 Subject: [PATCH 13/18] flake8: remove warnings import --- torchgeo/samplers/single.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index f30034c73a8..da65f9ce219 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -5,7 +5,6 @@ import abc import math -import warnings from typing import Callable, Iterable, Iterator, Optional, Tuple, Union import torch From 9d68d1e8d09cda7fb9ca675e40cb002ebe55d486 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 2 Sep 2022 20:00:37 -0700 Subject: [PATCH 14/18] Address some comments --- tests/samplers/test_single.py | 12 ++++-------- torchgeo/samplers/single.py | 1 - 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index db0e03bc4f9..7d0623ea563 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -171,13 +171,9 @@ def sampler(self, dataset: CustomGeoDataset, request: SubRequest) -> GridGeoSamp def test_iter(self, sampler: GridGeoSampler) -> None: for query in sampler: - assert sampler.roi.minx <= query.minx - assert sampler.roi.miny <= query.miny - assert sampler.roi.mint <= query.mint - if query.maxx > sampler.roi.maxx: - assert (query.maxx - sampler.roi.maxx) < sampler.size[1] - if query.maxy > sampler.roi.maxy: - assert (query.maxy - sampler.roi.maxy) < sampler.size[0] + assert sampler.roi.minx <= query.minx <= query.maxx <= sampler.roi.maxx + assert sampler.roi.miny <= query.miny <= query.miny <= sampler.roi.maxy + assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt assert math.isclose(query.maxx - query.minx, sampler.size[1]) assert math.isclose(query.maxy - query.miny, sampler.size[0]) @@ -193,7 +189,7 @@ def test_len(self, sampler: GridGeoSampler) -> None: (100 - sampler.size[1] + sampler.stride[1]) / sampler.stride[1] ) length = rows * cols * 2 # two items in dataset - assert len(sampler) == length + assert len(sampler) == length def test_len_larger(self, sampler: GridGeoSampler) -> None: entire_rows = (100 - sampler.size[0] + sampler.stride[0]) // sampler.stride[0] diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 125426d64dc..49bc4f3809f 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -210,7 +210,6 @@ def __init__( for hit in self.hits: bounds = BoundingBox(*hit.bounds) - # last patch samples outside the bounds rows = math.ceil( (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) / self.stride[0] From 8660f28a6273e4cf8c51cb031b6172995d207ff8 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 2 Sep 2022 20:13:20 -0700 Subject: [PATCH 15/18] Simplify computation of # rows/cols --- tests/samplers/test_single.py | 24 ++---------------------- torchgeo/samplers/single.py | 24 ++++++++++++------------ 2 files changed, 14 insertions(+), 34 deletions(-) diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 7d0623ea563..e6e27afedbc 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -182,31 +182,11 @@ def test_iter(self, sampler: GridGeoSampler) -> None: ) def test_len(self, sampler: GridGeoSampler) -> None: - rows = math.ceil( - (100 - sampler.size[0] + sampler.stride[0]) / sampler.stride[0] - ) - cols = math.ceil( - (100 - sampler.size[1] + sampler.stride[1]) / sampler.stride[1] - ) + rows = math.ceil((100 - sampler.size[0]) / sampler.stride[0]) + 1 + cols = math.ceil((100 - sampler.size[1]) / sampler.stride[1]) + 1 length = rows * cols * 2 # two items in dataset assert len(sampler) == length - def test_len_larger(self, sampler: GridGeoSampler) -> None: - entire_rows = (100 - sampler.size[0] + sampler.stride[0]) // sampler.stride[0] - entire_cols = (100 - sampler.size[1] + sampler.stride[1]) // sampler.stride[1] - leftover_row = (100 - sampler.size[0] + sampler.stride[0]) / sampler.stride[ - 0 - ] - entire_rows - leftover_col = (100 - sampler.size[1] + sampler.stride[1]) / sampler.stride[ - 1 - ] - entire_cols - assert ( - len(sampler) - == (entire_rows + math.ceil(leftover_row)) - * (entire_cols + math.ceil(leftover_col)) - * 2 - ) - def test_roi(self, dataset: CustomGeoDataset) -> None: roi = BoundingBox(0, 50, 200, 250, 400, 450) sampler = GridGeoSampler(dataset, 2, 1, roi=roi) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 49bc4f3809f..b0ee357bf48 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -210,13 +210,13 @@ def __init__( for hit in self.hits: bounds = BoundingBox(*hit.bounds) - rows = math.ceil( - (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) - / self.stride[0] + rows = ( + math.ceil((bounds.maxy - bounds.miny - self.size[0]) / self.stride[0]) + + 1 ) - cols = math.ceil( - (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) - / self.stride[1] + cols = ( + math.ceil((bounds.maxx - bounds.minx - self.size[1]) / self.stride[1]) + + 1 ) self.length += rows * cols @@ -230,13 +230,13 @@ def __iter__(self) -> Iterator[BoundingBox]: for hit in self.hits: bounds = BoundingBox(*hit.bounds) - rows = math.ceil( - (bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) - / self.stride[0] + rows = ( + math.ceil((bounds.maxy - bounds.miny - self.size[0]) / self.stride[0]) + + 1 ) - cols = math.ceil( - (bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) - / self.stride[1] + cols = ( + math.ceil((bounds.maxx - bounds.minx - self.size[1]) / self.stride[1]) + + 1 ) mint = bounds.mint From 95369706f3472d96ceb76824e1e20010bb5bfdca Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 2 Sep 2022 20:32:48 -0700 Subject: [PATCH 16/18] Document this new feature --- torchgeo/samplers/single.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index b0ee357bf48..6251b129457 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -147,7 +147,7 @@ def __len__(self) -> int: class GridGeoSampler(GeoSampler): - """Samples elements in a grid-like fashion. + r"""Samples elements in a grid-like fashion. This is particularly useful during evaluation when you want to make predictions for an entire region of interest. You want to minimize the amount of redundant @@ -159,6 +159,20 @@ class GridGeoSampler(GeoSampler): The overlap between each chip (``chip_size - stride``) should be approximately equal to the `receptive field `_ of the CNN. + + Note that the stride of the final set of chips in each row/column may be adjusted so + that the entire :term:`tile` is sampled without exceeding the bounds of the dataset. + + Let :math:`i` be the size of the input tile. Let :math:`k` be the requested size of + the output patch. Let :math:`s` be the requested stride. Let :math:`o` be the number + of output rows/columns sampled from each tile. :math:`o` can then be computed as: + + .. math:: + + o = \lceil \frac{i - k}{s} \rceil + 1 + + This is almost identical to relationship 5 in + https://doi.org/10.48550/arXiv.1603.07285. However, since we want to """ def __init__( From 32d877e86bfec0587597882970b3fbe396307e8c Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 2 Sep 2022 20:38:57 -0700 Subject: [PATCH 17/18] Fix size of ceiling symbol --- torchgeo/samplers/single.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 6251b129457..4b6d06ddf0b 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -169,7 +169,7 @@ class GridGeoSampler(GeoSampler): .. math:: - o = \lceil \frac{i - k}{s} \rceil + 1 + o = \left\lceil \frac{i - k}{s} \right\rceil + 1 This is almost identical to relationship 5 in https://doi.org/10.48550/arXiv.1603.07285. However, since we want to From a741129e91fdd6b1b0a612054c32a4999916b272 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 2 Sep 2022 20:58:52 -0700 Subject: [PATCH 18/18] Simplify tests --- tests/samplers/test_single.py | 21 +++++++-------------- torchgeo/samplers/single.py | 3 ++- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index e6e27afedbc..0528dfe78c4 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -199,8 +199,6 @@ def test_small_area(self) -> None: sampler = GridGeoSampler(ds, 2, 10) assert len(sampler) == 0 - # TODO: skip patches with area=0 when two tiles are - # side-by-side with an overlapping edge face. def test_tiles_side_by_side(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) @@ -209,27 +207,22 @@ def test_tiles_side_by_side(self) -> None: for bbox in sampler: assert bbox.area > 0 - def test_equal_area(self) -> None: + def test_integer_multiple(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) sampler = GridGeoSampler(ds, 10, 10, units=Units.CRS) + iterator = iter(sampler) assert len(sampler) == 1 - for bbox in sampler: - assert bbox == BoundingBox( - minx=0.0, maxx=10.0, miny=0.0, maxy=10.0, mint=0.0, maxt=10.0 - ) + assert next(iterator) == BoundingBox(0, 10, 0, 10, 0, 10) - def test_larger_area(self) -> None: + def test_float_multiple(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 6, 0, 5, 0, 10)) sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS) + iterator = iter(sampler) assert len(sampler) == 2 - assert list(sampler)[0] == BoundingBox( - minx=0.0, maxx=5.0, miny=0.0, maxy=5.0, mint=0.0, maxt=10.0 - ) - assert list(sampler)[1] == BoundingBox( - minx=1.0, maxx=6.0, miny=0.0, maxy=5.0, mint=0.0, maxt=10.0 - ) + assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10) + assert next(iterator) == BoundingBox(1, 6, 0, 5, 0, 10) @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 4b6d06ddf0b..e063d9ecbd4 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -172,7 +172,8 @@ class GridGeoSampler(GeoSampler): o = \left\lceil \frac{i - k}{s} \right\rceil + 1 This is almost identical to relationship 5 in - https://doi.org/10.48550/arXiv.1603.07285. However, since we want to + https://doi.org/10.48550/arXiv.1603.07285. However, we use ceiling instead of floor + because we want to include the final remaining chip. """ def __init__(