From d30cb4995dd3ddb208fed5842c14db15853553b2 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 12 May 2023 10:59:22 -0500 Subject: [PATCH 1/2] GridGeoSampler: don't change stride of last patch --- tests/samplers/test_single.py | 27 +++++++++++++++++++++++---- torchgeo/samplers/single.py | 9 --------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index c17dd7da34a..f43db1ce195 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -162,7 +162,16 @@ def dataset(self) -> CustomGeoDataset: @pytest.fixture( scope="function", params=product( - [(8, 1), (6, 2), (4, 3), (2.5, 3), ((8, 6), (1, 2)), ((6, 4), (2, 3))], + [ + (8, 1), + (6, 2), + (4, 3), + (4, 4), + (2, 4), + (2.5, 3), + ((8, 6), (1, 2)), + ((6, 4), (2, 3)), + ], [Units.PIXELS, Units.CRS], ), ) @@ -172,8 +181,18 @@ def sampler(self, dataset: CustomGeoDataset, request: SubRequest) -> GridGeoSamp def test_iter(self, sampler: GridGeoSampler) -> None: for query in sampler: - assert sampler.roi.minx <= query.minx <= query.maxx <= sampler.roi.maxx - assert sampler.roi.miny <= query.miny <= query.miny <= sampler.roi.maxy + assert ( + sampler.roi.minx + <= query.minx + <= query.maxx + < sampler.roi.maxx + sampler.stride[1] + ) + assert ( + sampler.roi.miny + <= query.miny + <= query.miny + < sampler.roi.maxy + sampler.stride[0] + ) assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt assert math.isclose(query.maxx - query.minx, sampler.size[1]) @@ -222,7 +241,7 @@ def test_float_multiple(self) -> None: iterator = iter(sampler) assert len(sampler) == 2 assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10) - assert next(iterator) == BoundingBox(1, 6, 0, 5, 0, 10) + assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10) @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 
26daa41e2d6..7251f4b274f 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -169,9 +169,6 @@ class GridGeoSampler(GeoSampler): The overlap between each chip (``chip_size - stride``) should be approximately equal to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of the CNN. - - Note that the stride of the final set of chips in each row/column may be adjusted so - that the entire :term:`tile` is sampled without exceeding the bounds of the dataset. """ def __init__( @@ -242,17 +239,11 @@ def __iter__(self) -> Iterator[BoundingBox]: for i in range(rows): miny = bounds.miny + i * self.stride[0] maxy = miny + self.size[0] - if maxy > bounds.maxy: - maxy = bounds.maxy - miny = bounds.maxy - self.size[0] # For each column... for j in range(cols): minx = bounds.minx + j * self.stride[1] maxx = minx + self.size[1] - if maxx > bounds.maxx: - maxx = bounds.maxx - minx = bounds.maxx - self.size[1] yield BoundingBox(minx, maxx, miny, maxy, mint, maxt) From 7d9a17d77d4e214faa640db9242439149f52c354 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 13 May 2023 13:07:18 -0500 Subject: [PATCH 2/2] Sample outside bounds of file --- torchgeo/datasets/geo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 9c440da896f..d4a81f73b46 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -487,6 +487,7 @@ def _merge_files( indexes=band_indexes, out_shape=out_shape, window=from_bounds(*bounds, src.transform), + boundless=True, ) else: dest, _ = rasterio.merge.merge(