Add MMFlood dataset #2450
Open

lccol wants to merge 7 commits into microsoft:main from lccol:main
Changes from all commits (7 commits):
76a6941 Add MMFlood dataset (lccol)
19ee181 Added tests for MMFloodDataModule (lccol)
344115c Merge branch 'main' into main (lccol)
41c118d added uncompressed test data folder and datamodule test. added versio… (lccol)
9d0f76f fix assertion (lccol)
d9bb5ef Merge branch 'main' into main (lccol)
37ac4ab updated docstring (lccol)
@@ -0,0 +1,18 @@
model:
  class_path: SemanticSegmentationTask
  init_args:
    loss: 'ce'
    model: 'unet'
    backbone: 'resnet18'
    in_channels: 3
    num_classes: 2
    num_filters: 1
data:
  class_path: MMFloodDataModule
  init_args:
    batch_size: 1
    dict_kwargs:
      root: 'tests/data/mmflood'
      patch_size: 8
      normalization: 'median'
      include_dem: True
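The in_channels: 3 setting matches the two SAR bands plus the extra DEM channel that include_dem: True adds. For context, the same wiring can be sketched directly in Python; this is a minimal sketch that assumes MMFloodDataModule lives in torchgeo.datamodules and accepts the dict_kwargs entries as keyword arguments (both assumptions, since those module files are not shown here):

from torchgeo.datamodules import MMFloodDataModule  # assumed import location
from torchgeo.trainers import SemanticSegmentationTask

# Mirrors the model section above: a tiny U-Net segmentation task on 3 input channels.
task = SemanticSegmentationTask(
    loss='ce',
    model='unet',
    backbone='resnet18',
    in_channels=3,
    num_classes=2,
    num_filters=1,
)

# Mirrors the data section; keyword names are taken from the config and assumed
# to be accepted by the datamodule constructor.
datamodule = MMFloodDataModule(
    batch_size=1,
    root='tests/data/mmflood',
    patch_size=8,
    normalization='median',
    include_dem=True,
)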
@@ -0,0 +1 @@
{"EMSR000": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "train", "delineations": ["EMSR000_00"]}, "EMSR001": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "train", "delineations": ["EMSR001_00"]}, "EMSR003": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "val", "delineations": ["EMSR003_00"]}, "EMSR004": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "test", "delineations": ["EMSR004_00"]}} |
Binary files not shown (26 files).
@@ -0,0 +1,125 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import json
import os
import tarfile

import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine


def generate_data(path: str, filename: str, height: int, width: int) -> None:
    MAX_VALUE = 1000.0
    MIN_VALUE = 0.0
    RANGE = MAX_VALUE - MIN_VALUE
    FOLDERS = ['s1_raw', 'DEM', 'mask']
Review comment on lines +15 to +18: lowercase would be better for local variables. Just note that
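For illustration, a lowercase version of these locals could read as follows (names are hypothetical; a plain range would shadow the Python builtin, so a different name is used here):

    # Hypothetical rename sketch, not part of the committed diff:
    max_value = 1000.0
    min_value = 0.0
    value_range = max_value - min_value
    folders = ['s1_raw', 'DEM', 'mask']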
    profile = {
        'driver': 'GTiff',
        'dtype': 'float32',
        'nodata': None,
        'crs': CRS.from_epsg(4326),
        'transform': Affine(
            0.0001287974837883981,
            0.0,
            14.438064999669106,
            0.0,
            -8.989523639880024e-05,
            45.71617928533084,
        ),
        'blockysize': 1,
        'tiled': False,
        'interleave': 'pixel',
        'height': height,
        'width': width,
    }
    data = {
        's1_raw': np.random.rand(2, height, width).astype(np.float32) * RANGE
        - MIN_VALUE,
        'DEM': np.random.rand(1, height, width).astype(np.float32) * RANGE - MIN_VALUE,
        'mask': np.random.randint(low=0, high=2, size=(1, height, width)).astype(
            np.uint8
        ),
    }

    os.makedirs(os.path.join(path, 'hydro'), exist_ok=True)

    for folder in FOLDERS:
        folder_path = os.path.join(path, folder)
        os.makedirs(folder_path, exist_ok=True)
        filepath = os.path.join(folder_path, filename)
        profile2 = profile.copy()
        profile2['count'] = 2 if folder == 's1_raw' else 1
        with rasterio.open(filepath, mode='w', **profile2) as src:
            src.write(data[folder])

    return


def generate_tar_gz(src: str, dst: str) -> None:
    with tarfile.open(dst, 'w:gz') as tar:
        tar.add(src, arcname=src)
    return


def split_tar(path: str, dst: str, nparts: int) -> None:
    fstats = os.stat(path)
    size = fstats.st_size
    chunk = size // nparts

    with open(path, 'rb') as fp:
        for idx in range(nparts):
            part_path = os.path.join(dst, f'activations.tar.{idx:03}.gz.part')

            bytes_to_write = chunk if idx < nparts - 1 else size - fp.tell()
            with open(part_path, 'wb') as dst_fp:
                dst_fp.write(fp.read(bytes_to_write))

    return


def generate_folders_and_metadata(datapath: str, metadatapath: str) -> None:
    folders_splits = [
        ('EMSR000', 'train'),
        ('EMSR001', 'train'),
        ('EMSR003', 'val'),
        ('EMSR004', 'test'),
    ]
    num_files = {'EMSR000': 3, 'EMSR001': 2, 'EMSR003': 2, 'EMSR004': 1}
    metadata = {}
    for folder, split in folders_splits:
        data = {}
        data['title'] = 'Test flood'
        data['type'] = 'Flood'
        data['country'] = 'N/A'
        data['start'] = '2014-11-06T17:57:00'
        data['end'] = '2015-01-29T12:47:04'
        data['lat'] = 45.82427031690563
        data['lon'] = 14.484407562009336
        data['subset'] = split
        data['delineations'] = [f'{folder}_00']

        dst_folder = os.path.join(datapath, f'{folder}-0')
        for idx in range(num_files[folder]):
            generate_data(
                dst_folder, filename=f'{folder}-{idx}.tif', height=16, width=16
            )

        metadata[folder] = data

    generate_tar_gz(src='activations', dst='activations.tar.gz')
    split_tar(path='activations.tar.gz', dst='.', nparts=2)
    os.remove('activations.tar.gz')
    with open(os.path.join(metadatapath, 'activations.json'), 'w') as fp:
        json.dump(metadata, fp)

    return


if __name__ == '__main__':
    datapath = os.path.join(os.getcwd(), 'activations')
    metadatapath = os.getcwd()

    generate_folders_and_metadata(datapath, metadatapath)
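With nparts=2, the script leaves activations.tar.000.gz.part and activations.tar.001.gz.part next to activations.json in the working directory. A minimal sketch of reassembling and extracting the archive, roughly what any consumer of these fixtures has to do (the dataset's own extraction logic is not shown, so this is an assumption):

import glob
import tarfile

# Concatenate the .part files in order, then extract the resulting archive.
# File names follow the pattern used in split_tar above.
parts = sorted(glob.glob('activations.tar.*.gz.part'))
with open('activations.tar.gz', 'wb') as out:
    for part in parts:
        with open(part, 'rb') as fp:
            out.write(fp.read())

with tarfile.open('activations.tar.gz', 'r:gz') as tar:
    tar.extractall('.')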
@@ -0,0 +1,100 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from rasterio.crs import CRS

from torchgeo.datasets import (
    BoundingBox,
    DatasetNotFoundError,
    IntersectionDataset,
    MMFlood,
    UnionDataset,
)


class TestMMFlood:
    @pytest.fixture(params=product([True, False], ['train', 'val', 'test']))
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> MMFlood:
        dataset_root = os.path.join('tests', 'data', 'mmflood/')
        url = os.path.join(dataset_root)

        monkeypatch.setattr(MMFlood, 'url', url)
        monkeypatch.setattr(MMFlood, '_nparts', 2)

        include_dem, split = request.param
        root = tmp_path
        return MMFlood(
            root,
            split=split,
            include_dem=include_dem,
            transforms=nn.Identity(),
            download=True,
            checksum=True,
        )

    def test_getitem(self, dataset: MMFlood) -> None:
        x = dataset[dataset.bounds]
        assert isinstance(x, dict)
        assert isinstance(x['crs'], CRS)
        assert isinstance(x['image'], torch.Tensor)
        assert isinstance(x['mask'], torch.Tensor)

        # If DEM is included, check if 3 channels are present, 2 otherwise
        if dataset.include_dem:
            assert x['image'].size(0) == 3
        else:
            assert x['image'].size(0) == 2
        return

    def test_len(self, dataset: MMFlood) -> None:
        if dataset.split == 'train':
            assert len(dataset) == 5
        elif dataset.split == 'val':
            assert len(dataset) == 2
        else:
            assert len(dataset) == 1

    def test_and(self, dataset: MMFlood) -> None:
        ds = dataset & dataset
        assert isinstance(ds, IntersectionDataset)

    def test_or(self, dataset: MMFlood) -> None:
        ds = dataset | dataset
        assert isinstance(ds, UnionDataset)

    def test_already_downloaded(self, dataset: MMFlood) -> None:
        MMFlood(root=dataset.root)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            MMFlood(tmp_path)

    def test_plot(self, dataset: MMFlood) -> None:
        x = dataset[dataset.bounds]
        dataset.plot(x, suptitle='Test')
        plt.close()

    def test_plot_prediction(self, dataset: MMFlood) -> None:
        x = dataset[dataset.bounds]
        x['prediction'] = x['mask'].clone()
        dataset.plot(x, suptitle='Prediction')
        plt.close()

    def test_invalid_query(self, dataset: MMFlood) -> None:
        query = BoundingBox(0, 0, 0, 0, 0, 0)
        with pytest.raises(
            IndexError, match='query: .* not found in index with bounds:'
        ):
            dataset[query]
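Because the dataset fixture is parametrized with product([True, False], ['train', 'val', 'test']), every test method above runs six times, once per (include_dem, split) combination. A quick sketch of the expansion:

from itertools import product

# The six (include_dem, split) configurations the fixture generates:
for include_dem, split in product([True, False], ['train', 'val', 'test']):
    print(f'include_dem={include_dem}, split={split}')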
Review comment: The paper is CC-BY-4.0, but the data is MIT, I would use MIT here.