Skip to content

Commit

Permalink
Add FTW DataModule (#2368)
Browse files Browse the repository at this point in the history
  • Loading branch information
nilsleh authored Oct 27, 2024
1 parent b61788c commit 52fb6e3
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ FAIR1M

.. autoclass:: FAIR1MDataModule

Fields Of The World
^^^^^^^^^^^^^^^^^^^

.. autoclass:: FieldsOfTheWorldDataModule

FireRisk
^^^^^^^^

Expand Down
16 changes: 16 additions & 0 deletions tests/conf/ftw.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: 'ce'
model: 'unet'
backbone: 'resnet18'
in_channels: 8
num_classes: 2
num_filters: 1
ignore_index: null
data:
class_path: FieldsOfTheWorldDataModule
init_args:
batch_size: 1
dict_kwargs:
root: 'tests/data/ftw'
3 changes: 3 additions & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class TestSemanticSegmentationTask:
'chesapeake_cvpr_7',
'deepglobelandcover',
'etci2021',
'ftw',
'geonrw',
'gid15',
'inria',
Expand Down Expand Up @@ -87,6 +88,8 @@ def test_trainer(
match name:
case 'chabud' | 'cabuar':
pytest.importorskip('h5py', minversion='3.6')
case 'ftw':
pytest.importorskip('pyarrow')
case 'landcoverai':
sha256 = (
'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .eurosat import EuroSAT100DataModule, EuroSATDataModule, EuroSATSpatialDataModule
from .fair1m import FAIR1MDataModule
from .fire_risk import FireRiskDataModule
from .ftw import FieldsOfTheWorldDataModule
from .geo import BaseDataModule, GeoDataModule, NonGeoDataModule
from .geonrw import GeoNRWDataModule
from .gid15 import GID15DataModule
Expand Down Expand Up @@ -79,6 +80,7 @@
'EuroSATSpatialDataModule',
'EuroSAT100DataModule',
'FAIR1MDataModule',
'FieldsOfTheWorldDataModule',
'FireRiskDataModule',
'GeoNRWDataModule',
'GID15DataModule',
Expand Down
86 changes: 86 additions & 0 deletions torchgeo/datamodules/ftw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""FTW datamodule."""

from typing import Any

import kornia.augmentation as K
import torch

from ..datasets import FieldsOfTheWorld
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule


class FieldsOfTheWorldDataModule(NonGeoDataModule):
"""LightningDataModule implementation for the FTW dataset.
.. versionadded:: 0.7
"""

mean = torch.tensor([0])
std = torch.tensor([3000])

def __init__(
self,
train_countries: list[str] = ['austria'],
val_countries: list[str] = ['austria'],
test_countries: list[str] = ['austria'],
batch_size: int = 64,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a new FTWDataModule instance.
Args:
train_countries: List of countries to use for training.
val_countries: List of countries to use for validation.
test_countries: List of countries to use for testing.
batch_size: Size of each mini-batch.
num_workers: Number of workers for parallel data loading.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.FieldsOfTheWorld`.
Raises:
AssertionError: If 'countries' are specified in kwargs
"""
assert (
'countries' not in kwargs
), "Please specify 'train_countries', 'val_countries', and 'test_countries' instead of 'countries' inside kwargs"

super().__init__(FieldsOfTheWorld, batch_size, num_workers, **kwargs)

self.train_countries = train_countries
self.val_countries = val_countries
self.test_countries = test_countries

self.train_aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.RandomRotation(p=0.5, degrees=90),
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
K.RandomSharpness(p=0.5),
data_keys=['image', 'mask'],
)
self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
)

def setup(self, stage: str) -> None:
"""Set up datasets.
Args:
stage: Either 'fit', 'validate', or 'test'.
"""
if stage in ['fit', 'validate']:
self.train_dataset = FieldsOfTheWorld(
split='train', countries=self.train_countries, **self.kwargs
)
self.val_dataset = FieldsOfTheWorld(
split='val', countries=self.val_countries, **self.kwargs
)
if stage in ['test']:
self.test_dataset = FieldsOfTheWorld(
split='test', countries=self.test_countries, **self.kwargs
)

0 comments on commit 52fb6e3

Please sign in to comment.