Skip to content

Commit

Permalink
Revert "Adjust minx/miny with a smaller stride for the last sample pe…
Browse files Browse the repository at this point in the history
…r row/col and issue warning"

This reverts commit cb554c6
  • Loading branch information
remtav committed Aug 29, 2022
1 parent 720cf5b commit af5a3d1
Showing 1 changed file with 35 additions and 5 deletions.
40 changes: 35 additions & 5 deletions torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import abc
import math
import warnings
from typing import Callable, Iterable, Iterator, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -108,8 +109,12 @@ def __init__(
areas = []
for hit in self.index.intersection(tuple(self.roi), objects=True):
bounds = BoundingBox(*hit.bounds)
self.hits.append(hit)
areas.append(bounds.area)
if (
bounds.maxx - bounds.minx >= self.size[1]
and bounds.maxy - bounds.miny >= self.size[0]
):
self.hits.append(hit)
areas.append(bounds.area)

# torch.multinomial requires float probabilities > 0
self.areas = torch.tensor(areas, dtype=torch.float)
Expand Down Expand Up @@ -185,15 +190,22 @@ def __init__(
.. versionchanged:: 0.3
Added ``units`` parameter, changed default to pixel units
"""
super().__init__(dataset=dataset, roi=roi)
super().__init__(dataset, roi)
self.size = _to_tuple(size)
self.stride = _to_tuple(stride)

if units == Units.PIXELS:
self.size = (self.size[0] * self.res, self.size[1] * self.res)
self.stride = (self.stride[0] * self.res, self.stride[1] * self.res)

self.hits = list(self.index.intersection(tuple(self.roi), objects=True))
self.hits = []
for hit in self.index.intersection(tuple(self.roi), objects=True):
bounds = BoundingBox(*hit.bounds)
if (
bounds.maxx - bounds.minx > self.size[1]
and bounds.maxy - bounds.miny > self.size[0]
):
self.hits.append(hit)

self.length: int = 0
for hit in self.hits:
Expand Down Expand Up @@ -236,11 +248,29 @@ def __iter__(self) -> Iterator[BoundingBox]:
for i in range(rows):
miny = bounds.miny + i * self.stride[0]
maxy = miny + self.size[0]
if maxy > bounds.maxy:
last_stride_y = self.stride[0] - (miny - (bounds.maxy - self.size[0]))
maxy = bounds.maxy
miny = bounds.maxy - self.size[0]
warnings.warn(
f"Max y coordinate of bounding box reaches passed y bounds of source tile"
f"Bounding box will be moved to set max y at source tile's max y. Stride will be adjusted"
f"from {self.stride[0]:.2f} to {last_stride_y:.2f}"
)

# For each column...
for j in range(cols):
minx = bounds.minx + j * self.stride[1]
maxx = minx + self.size[1]
if maxx > bounds.maxx:
last_stride_x = self.stride[1] - (minx - (bounds.maxx - self.size[1]))
maxx = bounds.maxx
minx = bounds.maxx - self.size[1]
warnings.warn(
f"Max x coordinate of bounding box reaches passed x bounds of source tile"
f"Bounding box will be moved to set max x at source tile's max x. Stride will be adjusted"
f"from {self.stride[1]:.2f} to {last_stride_x:.2f}"
)

yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)

Expand All @@ -260,7 +290,7 @@ class PreChippedGeoSampler(GeoSampler):
and subclass :class:`~torchgeo.datasets.GeoDataset` but have already been
pre-processed into :term:`chips <chip>`.
This sampler should not be used with :class:`~torchgeo.datasets.NonGeoDataset`.
This sampler should not be used with :class:`~torchgeo.datasets.VisionDataset`.
You may encounter problems when using an :term:`ROI <region of interest (ROI)>`
that partially intersects with one of the file bounding boxes, when using an
:class:`~torchgeo.datasets.IntersectionDataset`, or when each file is in a
Expand Down

0 comments on commit af5a3d1

Please sign in to comment.