
Commit 41c118d
added uncompressed test data folder and datamodule test. added versionadded and fixed _verify
lccol committed Dec 10, 2024
1 parent 344115c commit 41c118d
Showing 31 changed files with 96 additions and 139 deletions.
18 changes: 18 additions & 0 deletions tests/conf/mmflood.yaml
@@ -0,0 +1,18 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: 'ce'
model: 'unet'
backbone: 'resnet18'
in_channels: 3
num_classes: 2
num_filters: 1
data:
class_path: MMFloodDataModule
init_args:
batch_size: 1
dict_kwargs:
root: 'tests/data/mmflood'
patch_size: 8
normalization: 'median'
include_dem: True
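
For reference, a minimal sketch of the objects this test config resolves to when built directly in Python. It assumes the public torchgeo import paths and that the dict_kwargs entries are forwarded to the datamodule as keyword arguments; the hyperparameters simply mirror the YAML above.

```python
# Sketch: direct construction of the task and datamodule from the config.
from torchgeo.datamodules import MMFloodDataModule
from torchgeo.trainers import SemanticSegmentationTask

task = SemanticSegmentationTask(
    loss='ce',
    model='unet',
    backbone='resnet18',
    in_channels=3,
    num_classes=2,
    num_filters=1,
)
datamodule = MMFloodDataModule(
    batch_size=1,
    patch_size=8,            # from dict_kwargs, forwarded to the dataset
    normalization='median',
    include_dem=True,
    root='tests/data/mmflood',
)
```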
Binary files not shown (24 files).
2 changes: 0 additions & 2 deletions tests/data/mmflood/data.py
@@ -3,7 +3,6 @@

import json
import os
import shutil
import tarfile

import numpy as np
@@ -113,7 +112,6 @@ def generate_folders_and_metadata(datapath: str, metadatapath: str) -> None:
generate_tar_gz(src='activations', dst='activations.tar.gz')
split_tar(path='activations.tar.gz', dst='.', nparts=2)
os.remove('activations.tar.gz')
shutil.rmtree('activations')
with open(os.path.join(metadatapath, 'activations.json'), 'w') as fp:
json.dump(metadata, fp)

72 changes: 0 additions & 72 deletions tests/datamodules/test_mmflood.py

This file was deleted.

23 changes: 0 additions & 23 deletions tests/datasets/test_mmflood.py
@@ -98,26 +98,3 @@ def test_invalid_query(self, dataset: MMFlood) -> None:
IndexError, match='query: .* not found in index with bounds:'
):
dataset[query]

def test_check_folders(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> None:
class MockMMFlood(MMFlood):
def _load_folders(
self, check_folders: bool = False
) -> list[dict[str, str]]:
return super()._load_folders(check_folders=False)

dataset_root = os.path.join('tests', 'data', 'mmflood/')
url = os.path.join(dataset_root)

monkeypatch.setattr(MMFlood, 'url', url)
monkeypatch.setattr(MMFlood, '_nparts', 2)

_ = MockMMFlood(
tmp_path,
split='train',
include_dem=True,
transforms=nn.Identity(),
download=True,
checksum=True,
)
return
1 change: 1 addition & 0 deletions tests/trainers/test_segmentation.py
@@ -65,6 +65,7 @@ class TestSemanticSegmentationTask:
'landcoverai',
'landcoverai100',
'loveda',
'mmflood',
'naipchesapeake',
'potsdam2d',
'sen12ms_all',
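
Adding 'mmflood' to this parametrized list makes the trainer test suite pick up the new tests/conf/mmflood.yaml above. Below is a hedged sketch of the equivalent manual run, assuming torchgeo's LightningCLI entry point and standard Lightning flag spelling:

```python
# Roughly what the parametrized test exercises for 'mmflood':
# a single fast_dev_run fit over the tiny test dataset.
import os

from torchgeo.main import main  # torchgeo's LightningCLI wrapper

config = os.path.join('tests', 'conf', 'mmflood.yaml')
main(['fit', '--config', config, '--trainer.fast_dev_run', 'true'])
```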
8 changes: 4 additions & 4 deletions torchgeo/datamodules/mmflood.py
@@ -16,7 +16,10 @@


class MMFloodDataModule(GeoDataModule):
"""LightningDataModule implementation for the MMFlood dataset."""
"""LightningDataModule implementation for the MMFlood dataset.
.. versionadded:: 0.7
"""

# Computed over train set
mean = torch.tensor([0.1785585, 0.03574104, 168.45529])
@@ -75,8 +78,6 @@ def __init__(
K.Normalize(avg, self.std), keepdim=True, data_keys=None
)

return

def setup(self, stage: str) -> None:
"""Set up datasets.
@@ -98,4 +99,3 @@ def setup(self, stage: str) -> None:
self.test_sampler = GridGeoSampler(
self.test_dataset, self.patch_size, self.patch_size
)
return
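
For context, a minimal usage sketch of the datamodule once setup() has run, assuming standard GeoDataModule behavior: dataloaders are built from the dataset/sampler pairs assigned above, and GridGeoSampler tiles each split into patch_size x patch_size windows.

```python
# Hedged sketch: stage-based setup, then one validation batch.
dm = MMFloodDataModule(batch_size=1, patch_size=8, root='tests/data/mmflood')
dm.setup('fit')  # assigns train/val datasets and samplers
batch = next(iter(dm.val_dataloader()))
print(batch['image'].shape, batch['mask'].shape)
```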
111 changes: 73 additions & 38 deletions torchgeo/datasets/mmflood.py
@@ -43,6 +43,8 @@ class MMFlood(RasterDataset):
If you use this dataset in your research, please cite the following paper:

* https://doi.org/10.1109/ACCESS.2022.3205419

.. versionadded:: 0.7
"""

url = 'https://huggingface.co/datasets/links-ads/mmflood/resolve/24ca097306c9e50ad0711903c11e1ba13ea1bedc/'
@@ -122,12 +124,11 @@ def __init__(
# self.image_files, self.label_files, self.dem_files attributes
self._verify()
self.metadata_df = self._load_metadata()
self.folders = self._load_folders(check_folders=True)
self.folders = self._load_folders()
paths = [x['s1_raw'] for x in self.folders]

# Build the index
super().__init__(paths=paths, crs=crs, transforms=transforms, cache=cache)
return

def _merge_tar_files(self) -> None:
"""Merge part tar gz files."""
@@ -143,7 +144,6 @@

with open(part_path, 'rb') as part_fp:
dst_fp.write(part_fp.read())
return

def __getitem__(self, query: BoundingBox) -> dict[str, Tensor]:
"""Retrieve image/mask and metadata indexed by query.
@@ -194,25 +194,35 @@ def _load_metadata(self) -> pd.DataFrame:
).transpose()
return df

def _load_folders(self, check_folders: bool = False) -> list[dict[str, str]]:
"""Load folder paths.
def _load_tif_files(
self, check_folders: bool = False, load_all: bool = False
) -> dict[str, list[str]]:
"""Load paths of all tif files for Sentinel-1, DEM and masks.
Args:
check_folders: if True, verify pairings of all s1, dem and mask data across all the folders
check_folders: if True, verifies pairings of all s1, dem and mask data across all the folders
load_all: if True, loads all tif files contained in the "activations" folder under the specified root; otherwise, only acquisitions for the given split are loaded.
Returns:
list of dicts of s1, dem and masks folder paths
dict containing lists of paths, with 'image', 'dem' and 'mask' as keys
"""
paths = {}
dirpath = os.path.join(self.root, self.metadata['directory'])
# initialize tif file lists containing masks, DEM and S1_raw data
folders = self.metadata_df[
self.metadata_df['subset'] == self.split
].index.tolist()
if load_all:
# Get all directories
folders = os.listdir(dirpath)
else:
# Assemble regex for glob
folders = (
self.metadata_df[self.metadata_df['subset'] == self.split].index + '-*'
).tolist()

image_files = []
mask_files = []
dem_files = []
for f in folders:
path = os.path.join(self.root, self.metadata['directory'], f'{f}-*')
path = os.path.join(self.root, self.metadata['directory'], f)
image_files += glob(os.path.join(path, 's1_raw', '*.tif'))
mask_files += glob(os.path.join(path, 'mask', '*.tif'))
dem_files += glob(os.path.join(path, 'DEM', '*.tif'))
@@ -221,34 +231,42 @@ def _load_folders(self, check_folders: bool = False) -> list[dict[str, str]]:
mask_files = sorted(mask_files)
dem_files = sorted(dem_files)

paths['image'] = image_files
paths['mask'] = mask_files
paths['dem'] = dem_files

# Verify image, dem and mask lengths
assert (
len(image_files) > 0
len(paths['image']) > 0
), f'No images found, is the given path correct? ({self.root!s})'
assert (
len(image_files) == len(mask_files)
), f'Length mismatch between tiles and masks: {len(image_files)} != {len(mask_files)}'
assert len(image_files) == len(
dem_files
len(paths['image']) == len(paths['mask'])
), f'Length mismatch between tiles and masks: {len(paths["image"])} != {len(paths["mask"])}'
assert len(paths['image']) == len(
paths['dem']
), 'Length mismatch between tiles and DEMs'

if check_folders:
# Verify image, dem and mask pairings
self._verify_pairings(paths['image'], paths['dem'], paths['mask'])

return paths

def _load_folders(self) -> list[dict[str, str]]:
"""Load folder paths.
Returns:
list of dicts of s1, dem and masks folder paths
"""
paths = self._load_tif_files(check_folders=False, load_all=False)

res_folders = [
{'s1_raw': img_path, 'mask': mask_path, 'dem': dem_path}
for img_path, mask_path, dem_path in zip(image_files, mask_files, dem_files)
for img_path, mask_path, dem_path in zip(
paths['image'], paths['mask'], paths['dem']
)
]

if not check_folders:
return res_folders

# Verify image, dem and mask pairings
for image, mask, dem in zip(image_files, mask_files, dem_files):
image_tile = pathlib.Path(image).stem
mask_tile = pathlib.Path(mask).stem
dem_tile = pathlib.Path(dem).stem
assert (
image_tile == mask_tile == dem_tile
), f'Filenames not matching: image {image_tile}; mask {mask_tile}; dem {dem_tile}'

return res_folders

def _load_image(self, index: list[int], query: BoundingBox) -> Tensor:
@@ -299,7 +317,7 @@ def _load_target(self, index: list[int], query: BoundingBox) -> Tensor:
the target mask
"""
tensor = self._load_tif(index, modality='mask', query=query).type(torch.uint8)
return tensor.squeeze(dim=0)
return tensor.long().squeeze(dim=0)

def _download(self) -> None:
"""Download the dataset."""
@@ -323,32 +341,49 @@ def _check_and_download(filename: str, url: str) -> None:
_check_and_download(
self.metadata['metadata_file'], self.url + self.metadata['metadata_file']
)
return

def _extract(self) -> None:
"""Extract the dataset.
Args:
filepath: path to file to be extracted
"""
"""Extract the dataset."""
filepath = os.path.join(self.root, self.metadata['filename'])
if str(filepath).endswith('.tar.gz'):
extract_archive(filepath)
return

def _verify(self) -> None:
"""Verify the integrity of the dataset."""
dirpath = os.path.join(self.root, self.metadata['directory'])
metadata_filepath = os.path.join(self.root, self.metadata['metadata_file'])
# Check if both metadata file and directory exist
if os.path.isdir(dirpath) and os.path.isfile(metadata_filepath):
# Check pairings of all files
_ = self._load_tif_files(check_folders=True, load_all=True)
return
if not self.download:
raise DatasetNotFoundError(self)
self._download()
self._merge_tar_files()
self._extract()
return

def _verify_pairings(
self, s1_paths: list[str], dem_paths: list[str], mask_paths: list[str]
) -> None:
"""Verify all pairings of Sentinel-1, DEM and mask tif files. All inputs must be sorted.
Args:
s1_paths: list of paths of Sentinel-1 tif files
dem_paths: list of paths of DEM tif files
mask_paths: list of paths of mask tif files
"""
assert (
len(s1_paths) == len(dem_paths) == len(mask_paths)
), f'Lengths of s1, dem and mask files do not match! ({len(s1_paths)}, {len(dem_paths)}, {len(mask_paths)})'

for image, mask, dem in zip(s1_paths, mask_paths, dem_paths):
image_tile = pathlib.Path(image).stem
mask_tile = pathlib.Path(mask).stem
dem_tile = pathlib.Path(dem).stem
assert (
image_tile == mask_tile == dem_tile
), f'Filenames not matching: image {image_tile}; mask {mask_tile}; dem {dem_tile}'

def plot(
self,
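
The new _verify_pairings method relies on a simple invariant: across the s1_raw, DEM and mask folders, paired tif files share the same file stem. A small illustration under a hypothetical file layout (paths invented for the example):

```python
# Hypothetical layout showing the stem-based pairing invariant.
import pathlib

s1 = 'activations/EMSR000-0/s1_raw/EMSR000-0-0.tif'
dem = 'activations/EMSR000-0/DEM/EMSR000-0-0.tif'
mask = 'activations/EMSR000-0/mask/EMSR000-0-0.tif'

stems = {pathlib.Path(p).stem for p in (s1, dem, mask)}
assert len(stems) == 1  # all three modalities reference the same tile
```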
