Skip to content

Commit

Permalink
Adjust minx/miny with a smaller stride for the last sample per row/co…
Browse files Browse the repository at this point in the history
…l and issue warning
  • Loading branch information
remtav committed Jun 28, 2022
1 parent 7b5a92b commit cb554c6
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
"""TorchGeo samplers."""

import abc
import random
import warnings
from typing import Callable, Iterable, Iterator, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -188,7 +190,7 @@ def __init__(
.. versionchanged:: 0.3
Added ``units`` parameter, changed default to pixel units
"""
super().__init__(dataset, roi)
super().__init__(dataset=dataset, roi=roi, stride=stride, size=size)
self.size = _to_tuple(size)
self.stride = _to_tuple(stride)

Expand All @@ -209,8 +211,8 @@ def __init__(
for hit in self.hits:
bounds = BoundingBox(*hit.bounds)

rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
rows = int((bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) // self.stride[0]) + 1
cols = int((bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) // self.stride[1]) + 1
self.length += rows * cols

def __iter__(self) -> Iterator[BoundingBox]:
Expand All @@ -223,8 +225,8 @@ def __iter__(self) -> Iterator[BoundingBox]:
for hit in self.hits:
bounds = BoundingBox(*hit.bounds)

rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
rows = int((bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) // self.stride[0]) + 1
cols = int((bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) // self.stride[1]) + 1

mint = bounds.mint
maxt = bounds.maxt
Expand All @@ -233,11 +235,29 @@ def __iter__(self) -> Iterator[BoundingBox]:
for i in range(rows):
miny = bounds.miny + i * self.stride[0]
maxy = miny + self.size[0]
if maxy > bounds.maxy:
last_stride_y = self.stride[0] - (miny - (bounds.maxy - self.size[0]))
maxy = bounds.maxy
miny = bounds.maxy - self.size[0]
warnings.warn(
f"Max y coordinate of bounding box reaches passed y bounds of source tile"
f"Bounding box will be moved to set max y at source tile's max y. Stride will be adjusted"
f"from {self.stride[0]:.2f} to {last_stride_y:.2f}"
)

# For each column...
for j in range(cols):
minx = bounds.minx + j * self.stride[1]
maxx = minx + self.size[1]
if maxx > bounds.maxx:
last_stride_x = self.stride[1] - (minx - (bounds.maxx - self.size[1]))
maxx = bounds.maxx
minx = bounds.maxx - self.size[1]
warnings.warn(
f"Max x coordinate of bounding box reaches passed x bounds of source tile"
f"Bounding box will be moved to set max x at source tile's max x. Stride will be adjusted"
f"from {self.stride[1]:.2f} to {last_stride_x:.2f}"
)

yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)

Expand Down

0 comments on commit cb554c6

Please sign in to comment.