Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add splitting utilities for GeoDatasets #866

Merged
merged 51 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
85810c0
add extent_crop to BoundingBox
pmandiola Oct 24, 2022
45f963b
add extent_crop param to RasterDataset
pmandiola Oct 24, 2022
cf8e824
train_test_split function
pmandiola Oct 27, 2022
a50c983
minor changes
pmandiola Nov 2, 2022
85d5718
fix circular import
pmandiola Nov 2, 2022
12fc826
Merge branch 'main' into feature/split_raster_datasets
pmandiola Dec 14, 2022
d343c3b
remove extent_crop
pmandiola Dec 20, 2022
bd65c85
move existing functions to new file
pmandiola Dec 20, 2022
6f694e8
refactor random_nongeo_split
pmandiola Dec 20, 2022
9a80943
refactor random_bbox_splitting
pmandiola Dec 20, 2022
8c54b84
add roi_split
pmandiola Dec 20, 2022
6698745
add random_bbox_assignment
pmandiola Dec 20, 2022
4de89b4
add input checks
pmandiola Dec 20, 2022
38a8b1b
fix input type
pmandiola Dec 21, 2022
3ca7c5f
minor reorder
pmandiola Dec 21, 2022
c3ff112
add tests
pmandiola Dec 21, 2022
49fc2d1
add non-overlapping test
pmandiola Dec 21, 2022
d8ad1b4
more tests
pmandiola Dec 21, 2022
8c80c9b
fix tests
pmandiola Dec 21, 2022
d676ed1
additional tests
pmandiola Dec 22, 2022
503478c
check overlapping rois
pmandiola Dec 22, 2022
4a96e40
add time_series_split with tests
pmandiola Dec 29, 2022
8842d78
fix random_nongeo_split to work with fractions in torch 1.9
pmandiola Dec 29, 2022
c9b8b19
modify random_nongeo_split test for coverage
pmandiola Dec 29, 2022
54a1781
add random_grid_cell_assignment with tests
pmandiola Jan 3, 2023
2997b95
add test
pmandiola Jan 3, 2023
a27d09d
insert object into new indexes
pmandiola Jan 4, 2023
3a805c0
check grid_size
pmandiola Jan 4, 2023
56a3692
better tests
pmandiola Jan 4, 2023
6cda093
small type fix
pmandiola Jan 4, 2023
5d87c71
fix again
pmandiola Jan 4, 2023
7adca07
rm .DS_Store
pmandiola Jan 4, 2023
e19c7ad
fix typo
pmandiola Feb 15, 2023
1f99166
Merge branch 'microsoft:main' into feature/split_raster_datasets
pmandiola Feb 15, 2023
cab469f
bump version added
pmandiola Feb 15, 2023
771f885
add to __init__
pmandiola Feb 15, 2023
41f7308
add to datasets.rst
pmandiola Feb 15, 2023
71f7503
use accumulate from itertools
pmandiola Feb 15, 2023
91445b7
clarify grid_size
pmandiola Feb 15, 2023
4fcd141
remove random_nongeo_split
pmandiola Feb 15, 2023
fa513b5
remove _create_geodataset_like
pmandiola Feb 15, 2023
3f46921
black reformatting
pmandiola Feb 15, 2023
5d2f866
Update tests/datasets/test_splits.py
pmandiola Feb 20, 2023
28c24fb
change import
pmandiola Feb 20, 2023
358a679
docstrings
pmandiola Feb 20, 2023
3610c1f
undo intersection change
pmandiola Feb 20, 2023
d589ed1
use microsecond
pmandiola Feb 20, 2023
d6fc778
use isclose
pmandiola Feb 20, 2023
9c9fd26
black
pmandiola Feb 20, 2023
15ffb24
fix typing
pmandiola Feb 21, 2023
0182180
add comments
pmandiola Feb 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,12 @@ Collation Functions
.. autofunction:: concat_samples
.. autofunction:: merge_samples
.. autofunction:: unbind_samples

Splitting Functions
^^^^^^^^^^^^^^^^^^^

.. autofunction:: random_bbox_assignment
.. autofunction:: random_bbox_splitting
.. autofunction:: random_grid_cell_assignment
.. autofunction:: roi_split
.. autofunction:: time_series_split
322 changes: 322 additions & 0 deletions tests/datasets/test_splits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from math import floor, isclose
from typing import Any, Dict, List, Sequence, Tuple, Union

import pytest
from rasterio.crs import CRS

from torchgeo.datasets import (
BoundingBox,
GeoDataset,
random_bbox_assignment,
random_bbox_splitting,
random_grid_cell_assignment,
roi_split,
time_series_split,
)


class CustomGeoDataset(GeoDataset):
    """Minimal in-memory ``GeoDataset`` used to exercise the splitting utilities.

    Each item is a ``(BoundingBox, content)`` pair stored in the dataset's
    spatial index; ``__getitem__`` returns the content of the first index
    entry that intersects the query.
    """

    def __init__(
        self,
        # Immutable tuple default: a list default would be a single object
        # shared (and potentially mutated) across every call.
        items: Sequence[Tuple[BoundingBox, str]] = (
            (BoundingBox(0, 1, 0, 1, 0, 40), ""),
        ),
        crs: CRS = CRS.from_epsg(3005),
        res: float = 1,
    ) -> None:
        """Populate the index with the given items.

        Args:
            items: pairs of bounding box and the string stored alongside it
            crs: coordinate reference system assigned to the dataset
            res: resolution assigned to the dataset
        """
        super().__init__()
        for box, content in items:
            # Every entry is inserted under id 0; this class only ever reads
            # back an entry's stored object, never its id.
            self.index.insert(0, tuple(box), content)
        self._crs = crs
        self.res = res

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        """Return the content of the first index entry intersecting ``query``.

        Args:
            query: bounding box used to look up the index

        Returns:
            dict with a single ``"content"`` key holding the stored object

        Raises:
            IndexError: if no index entry intersects ``query``
        """
        hits = self.index.intersection(tuple(query), objects=True)
        # Use a sentinel instead of letting StopIteration escape: a leaked
        # StopIteration is silently swallowed by for-loops and turned into
        # RuntimeError inside generators.
        hit = next(iter(hits), None)
        if hit is None:
            raise IndexError(
                f"query: {query} not found in index with bounds: {self.bounds}"
            )
        return {"content": hit.object}


@pytest.mark.parametrize(
    "lengths,expected_lengths",
    [
        # List of lengths
        ([2, 1, 1], [2, 1, 1]),
        # List of fractions (with remainder)
        ([1 / 3, 1 / 3, 1 / 3], [2, 1, 1]),
    ],
)
def test_random_bbox_assignment(
    lengths: Sequence[Union[int, float]], expected_lengths: Sequence[int]
) -> None:
    """Check that ``random_bbox_assignment`` distributes whole boxes as requested.

    Args:
        lengths: absolute lengths or fractions passed to the split function
        expected_lengths: expected number of index entries in each split
    """
    ds = CustomGeoDataset(
        [
            (BoundingBox(0, 1, 0, 1, 0, 0), "a"),
            (BoundingBox(1, 2, 0, 1, 0, 0), "b"),
            (BoundingBox(2, 3, 0, 1, 0, 0), "c"),
            (BoundingBox(3, 4, 0, 1, 0, 0), "d"),
        ]
    )

    train_ds, val_ds, test_ds = random_bbox_assignment(ds, lengths)

    # Check datasets lengths
    assert len(train_ds) == expected_lengths[0]
    assert len(val_ds) == expected_lengths[1]
    assert len(test_ds) == expected_lengths[2]

    # No overlap: splits may share nothing, or only zero-area geometry.
    # abs_tol is required here; with the default abs_tol=0.0, isclose(x, 0)
    # is only true when x is exactly 0.
    for first, second in ((train_ds, val_ds), (val_ds, test_ds), (test_ds, train_ds)):
        overlap = first & second
        assert len(overlap) == 0 or isclose(
            _get_total_area(overlap), 0, abs_tol=1e-9
        )

    # Union equals original
    assert (train_ds | val_ds | test_ds).bounds == ds.bounds

    # Test __getitem__
    x = train_ds[train_ds.bounds]
    assert isinstance(x, dict)
    assert isinstance(x["content"], str)


def test_random_bbox_assignment_invalid_inputs() -> None:
    """Check that ``random_bbox_assignment`` rejects malformed ``lengths``."""
    # Lengths that sum to neither 1 nor the number of index entries.
    bad_sum_message = (
        "Sum of input lengths must equal 1 or the length of dataset's index."
    )
    with pytest.raises(ValueError, match=bad_sum_message):
        random_bbox_assignment(CustomGeoDataset(), lengths=[2, 2, 1])

    # Lengths that sum to 1 but contain a negative entry.
    non_positive_message = "All items in input lengths must be greater than 0."
    with pytest.raises(ValueError, match=non_positive_message):
        random_bbox_assignment(CustomGeoDataset(), lengths=[1 / 2, 3 / 4, -1 / 4])


def _get_total_area(dataset: GeoDataset) -> float:
    """Sum the areas of all bounding boxes stored in a dataset's index.

    Args:
        dataset: dataset whose index entries are measured

    Returns:
        the accumulated ``BoundingBox.area`` over every entry in the index
    """
    index = dataset.index
    area = 0.0
    # Querying the index with its own bounds visits every stored entry.
    for entry in index.intersection(index.bounds, objects=True):
        box = BoundingBox(*entry.bounds)
        area += box.area
    return area


def test_random_bbox_splitting() -> None:
    """Check that ``random_bbox_splitting`` divides the covered area by fraction."""
    ds = CustomGeoDataset(
        [
            (BoundingBox(0, 1, 0, 1, 0, 0), "a"),
            (BoundingBox(1, 2, 0, 1, 0, 0), "b"),
            (BoundingBox(2, 3, 0, 1, 0, 0), "c"),
            (BoundingBox(3, 4, 0, 1, 0, 0), "d"),
        ]
    )

    ds_area = _get_total_area(ds)

    train_ds, val_ds, test_ds = random_bbox_splitting(
        ds, fractions=[1 / 2, 1 / 4, 1 / 4]
    )
    train_ds_area = _get_total_area(train_ds)
    val_ds_area = _get_total_area(val_ds)
    test_ds_area = _get_total_area(test_ds)

    # Check datasets areas. Areas are sums of floating-point products, so
    # compare with a tolerance rather than exact equality.
    assert isclose(train_ds_area, ds_area / 2)
    assert isclose(val_ds_area, ds_area / 4)
    assert isclose(test_ds_area, ds_area / 4)

    # No overlap: splits may share nothing, or only zero-area geometry.
    # abs_tol is required here; with the default abs_tol=0.0, isclose(x, 0)
    # is only true when x is exactly 0.
    for first, second in ((train_ds, val_ds), (val_ds, test_ds), (test_ds, train_ds)):
        overlap = first & second
        assert len(overlap) == 0 or isclose(
            _get_total_area(overlap), 0, abs_tol=1e-9
        )

    # Union equals original
    assert (train_ds | val_ds | test_ds).bounds == ds.bounds
    assert isclose(_get_total_area(train_ds | val_ds | test_ds), ds_area)

    # Test __getitem__
    x = train_ds[train_ds.bounds]
    assert isinstance(x, dict)
    assert isinstance(x["content"], str)

    # Test invalid input fractions
    with pytest.raises(ValueError, match="Sum of input fractions must equal 1."):
        random_bbox_splitting(ds, fractions=[1 / 2, 1 / 3, 1 / 4])
    with pytest.raises(
        ValueError, match="All items in input fractions must be greater than 0."
    ):
        random_bbox_splitting(ds, fractions=[1 / 2, 3 / 4, -1 / 4])


def test_random_grid_cell_assignment() -> None:
    """Check that ``random_grid_cell_assignment`` distributes grid cells by fraction."""
    ds = CustomGeoDataset(
        [
            (BoundingBox(0, 12, 0, 12, 0, 0), "a"),
            (BoundingBox(12, 24, 0, 12, 0, 0), "b"),
        ]
    )

    train_ds, val_ds, test_ds = random_grid_cell_assignment(
        ds, fractions=[1 / 2, 1 / 4, 1 / 4], grid_size=5
    )

    # Check datasets lengths: 2 boxes * 5**2 cells = 50 cells in total.
    # The quarter shares floor to 12 each; the "+ 1" on the first split
    # presumably accounts for the leftover cell (verify against splits.py).
    assert len(train_ds) == 1 / 2 * 2 * 5**2 + 1
    assert len(val_ds) == floor(1 / 4 * 2 * 5**2)
    assert len(test_ds) == floor(1 / 4 * 2 * 5**2)

    # No overlap: splits may share nothing, or only zero-area geometry.
    # abs_tol is required here; with the default abs_tol=0.0, isclose(x, 0)
    # is only true when x is exactly 0.
    for first, second in ((train_ds, val_ds), (val_ds, test_ds), (test_ds, train_ds)):
        overlap = first & second
        assert len(overlap) == 0 or isclose(
            _get_total_area(overlap), 0, abs_tol=1e-9
        )

    # Union equals original
    assert (train_ds | val_ds | test_ds).bounds == ds.bounds
    assert isclose(_get_total_area(train_ds | val_ds | test_ds), _get_total_area(ds))

    # Test __getitem__
    x = train_ds[train_ds.bounds]
    assert isinstance(x, dict)
    assert isinstance(x["content"], str)

    # Test invalid input fractions
    with pytest.raises(ValueError, match="Sum of input fractions must equal 1."):
        random_grid_cell_assignment(ds, fractions=[1 / 2, 1 / 3, 1 / 4])
    with pytest.raises(
        ValueError, match="All items in input fractions must be greater than 0."
    ):
        random_grid_cell_assignment(ds, fractions=[1 / 2, 3 / 4, -1 / 4])
    with pytest.raises(ValueError, match="Input grid_size must be greater than 1."):
        random_grid_cell_assignment(ds, fractions=[1 / 2, 1 / 4, 1 / 4], grid_size=1)


def test_roi_split() -> None:
    """Check that ``roi_split`` partitions a dataset by the given regions."""
    ds = CustomGeoDataset(
        [
            (BoundingBox(0, 1, 0, 1, 0, 0), "a"),
            (BoundingBox(1, 2, 0, 1, 0, 0), "b"),
            (BoundingBox(2, 3, 0, 1, 0, 0), "c"),
            (BoundingBox(3, 4, 0, 1, 0, 0), "d"),
        ]
    )

    # The second and third ROIs meet at x=3.5, inside box "d", so that box
    # is expected to contribute one entry to each of the last two splits.
    train_ds, val_ds, test_ds = roi_split(
        ds,
        rois=[
            BoundingBox(0, 2, 0, 1, 0, 0),
            BoundingBox(2, 3.5, 0, 1, 0, 0),
            BoundingBox(3.5, 4, 0, 1, 0, 0),
        ],
    )

    # Check datasets lengths
    assert len(train_ds) == 2
    assert len(val_ds) == 2
    assert len(test_ds) == 1

    # No overlap: splits may share nothing, or only zero-area geometry.
    # abs_tol is required here; with the default abs_tol=0.0, isclose(x, 0)
    # is only true when x is exactly 0.
    for first, second in ((train_ds, val_ds), (val_ds, test_ds), (test_ds, train_ds)):
        overlap = first & second
        assert len(overlap) == 0 or isclose(
            _get_total_area(overlap), 0, abs_tol=1e-9
        )

    # Union equals original
    assert (train_ds | val_ds | test_ds).bounds == ds.bounds
    assert isclose(_get_total_area(train_ds | val_ds | test_ds), _get_total_area(ds))

    # Test __getitem__
    x = train_ds[train_ds.bounds]
    assert isinstance(x, dict)
    assert isinstance(x["content"], str)

    # Test invalid input rois
    with pytest.raises(ValueError, match="ROIs in input rois can't overlap."):
        roi_split(
            ds, rois=[BoundingBox(0, 2, 0, 1, 0, 0), BoundingBox(1, 3, 0, 1, 0, 0)]
        )


@pytest.mark.parametrize(
    "lengths,expected_lengths",
    [
        # List of timestamps
        ([(0, 20), (20, 35), (35, 40)], [2, 2, 1]),
        # List of lengths
        ([20, 15, 5], [2, 2, 1]),
        # List of fractions (with remainder)
        ([1 / 2, 3 / 8, 1 / 8], [2, 2, 1]),
    ],
)
def test_time_series_split(
    lengths: Sequence[Union[Tuple[int, int], int, float]],
    expected_lengths: Sequence[int],
) -> None:
    """Check that ``time_series_split`` partitions a dataset along its time axis.

    Args:
        lengths: timestamp pairs, absolute durations or fractions to split by
        expected_lengths: expected number of index entries in each split
    """
    # Four boxes sharing one spatial footprint, covering consecutive time spans.
    dataset = CustomGeoDataset(
        [
            (BoundingBox(0, 1, 0, 1, 0, 10), "a"),
            (BoundingBox(0, 1, 0, 1, 10, 20), "b"),
            (BoundingBox(0, 1, 0, 1, 20, 30), "c"),
            (BoundingBox(0, 1, 0, 1, 30, 40), "d"),
        ]
    )

    train_ds, val_ds, test_ds = time_series_split(dataset, lengths)

    # Check datasets lengths
    for position, split in enumerate((train_ds, val_ds, test_ds)):
        assert len(split) == expected_lengths[position]

    # No overlap
    for left, right in ((train_ds, val_ds), (val_ds, test_ds), (test_ds, train_ds)):
        assert len(left & right) == 0

    # Union equals original
    merged = train_ds | val_ds | test_ds
    assert merged.bounds == dataset.bounds

    # Test __getitem__
    sample = train_ds[train_ds.bounds]
    assert isinstance(sample, dict)
    assert isinstance(sample["content"], str)


def test_time_series_split_invalid_input() -> None:
    """Check that ``time_series_split`` rejects malformed ``lengths``.

    Each case pairs the expected error message with the offending input and
    is run against a fresh default ``CustomGeoDataset``.
    """
    invalid_cases: List[Tuple[str, Sequence[Any]]] = [
        (
            "Pairs of timestamps in lengths must have end greater than start.",
            [(0, 20), (35, 20), (35, 40)],
        ),
        (
            "Pairs of timestamps in lengths must cover dataset's time bounds.",
            [(0, 20), (20, 35)],
        ),
        (
            "Pairs of timestamps in lengths can't be out of dataset's time bounds.",
            [(0, 20), (20, 45)],
        ),
        (
            "Pairs of timestamps in lengths can't overlap.",
            [(0, 10), (10, 20), (15, 40)],
        ),
        (
            "Sum of input lengths must equal 1 or the dataset's time length.",
            [1 / 2, 1 / 2, 1 / 2],
        ),
        (
            "All items in input lengths must be greater than 0.",
            [20, 25, -5],
        ),
    ]

    for message, bad_lengths in invalid_cases:
        with pytest.raises(ValueError, match=message):
            time_series_split(CustomGeoDataset(), lengths=bad_lengths)
29 changes: 29 additions & 0 deletions tests/datasets/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,35 @@ def test_intersects(
bbox2 = BoundingBox(*test_input)
assert bbox1.intersects(bbox2) == bbox2.intersects(bbox1) == expected

@pytest.mark.parametrize(
    "proportion,horizontal,expected",
    [
        (0.25, True, ((0, 0.25, 0, 1, 0, 1), (0.25, 1, 0, 1, 0, 1))),
        (0.25, False, ((0, 1, 0, 0.25, 0, 1), (0, 1, 0.25, 1, 0, 1))),
    ],
)
def test_split(
    self,
    proportion: float,
    horizontal: bool,
    expected: Tuple[
        Tuple[float, float, float, float, float, float],
        Tuple[float, float, float, float, float, float],
    ],
) -> None:
    """Check that ``BoundingBox.split`` yields the two expected sub-boxes.

    Args:
        proportion: split position passed to ``BoundingBox.split``
        horizontal: split direction flag passed to ``BoundingBox.split``
        expected: coordinates of the first and second resulting boxes
    """
    expected_first, expected_second = expected
    whole = BoundingBox(0, 1, 0, 1, 0, 1)

    first, second = whole.split(proportion, horizontal)

    assert first == BoundingBox(*expected_first)
    assert second == BoundingBox(*expected_second)
    # The two parts together must rebuild the original box.
    assert first | second == whole

def test_split_error(self) -> None:
    """Check that ``BoundingBox.split`` rejects a proportion of 1.5."""
    unit_box = BoundingBox(0, 1, 0, 1, 0, 1)
    expected_message = "Input proportion must be between 0 and 1."
    with pytest.raises(ValueError, match=expected_message):
        unit_box.split(1.5)

def test_picklable(self) -> None:
bbox = BoundingBox(0, 1, 2, 3, 4, 5)
x = pickle.dumps(bbox)
Expand Down
13 changes: 13 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@
SpaceNet6,
SpaceNet7,
)
from .splits import (
random_bbox_assignment,
random_bbox_splitting,
random_grid_cell_assignment,
roi_split,
time_series_split,
)
from .ucmerced import UCMerced
from .usavars import USAVars
from .utils import (
Expand Down Expand Up @@ -207,4 +214,10 @@
"merge_samples",
"stack_samples",
"unbind_samples",
# Splits
"random_bbox_assignment",
"random_bbox_splitting",
"random_grid_cell_assignment",
"roi_split",
"time_series_split",
)
Loading