Skip to content

Commit

Permalink
Add SpaceNet6 datamodule (#2367)
Browse files Browse the repository at this point in the history
* spacenet 6 datamodule

* datamodule base class

* mypy

* fix tests

* fix docs

* fix tests again

* uncomment test

* magic comma

* review

* class -> instance

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
  • Loading branch information
nilsleh and adamjstewart authored Oct 27, 2024
1 parent 52fb6e3 commit c746cac
Show file tree
Hide file tree
Showing 68 changed files with 521 additions and 153 deletions.
2 changes: 2 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ So2Sat
SpaceNet
^^^^^^^^

.. autoclass:: SpaceNetBaseDataModule
.. autoclass:: SpaceNet1DataModule
.. autoclass:: SpaceNet6DataModule

SSL4EO
^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/spacenet1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ data:
val_split_pct: 0.34
test_split_pct: 0.34
dict_kwargs:
root: 'tests/data/spacenet'
root: 'tests/data/spacenet/spacenet1'
19 changes: 19 additions & 0 deletions tests/conf/spacenet6.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: 'ce'
model: 'unet'
backbone: 'resnet18'
in_channels: 4
num_classes: 3
num_filters: 1
ignore_index: null
data:
class_path: SpaceNet6DataModule
init_args:
batch_size: 1
val_split_pct: 0.34
test_split_pct: 0.34
dict_kwargs:
root: 'tests/data/spacenet/spacenet6'
image: 'SAR-Intensity'
Binary file not shown.
Binary file not shown.
Binary file not shown.
92 changes: 0 additions & 92 deletions tests/data/spacenet/data.py

This file was deleted.

Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type": "FeatureCollection", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "geometry": {"type": "Polygon", "coordinates": [[[-43.7720361, -22.922229499999958, 0.0], [-43.772064, -22.9222724, 0.0], [-43.77210239999994, -22.922247399999947, 0.0], [-43.772074499999974, -22.9222046, 0.0], [-43.7720361, -22.922229499999958, 0.0]]]}}]}
161 changes: 161 additions & 0 deletions tests/data/spacenet/spacenet1/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import json
import os
import shutil
from typing import Any

import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine

# Width and height, in pixels, of every generated raster (also the shape of Z).
SIZE = 2

# Number of images written per band directory and of label files per split.
NUM_SAMPLES = 4

# Name of the output directory, created relative to the current working directory.
dataset_id = 'SN1_buildings'

# Keyword arguments passed to rasterio.open() in write mode. The per-file band
# count ('count') is not set here; generate_geotiff_files supplies it.
profile = {
'driver': 'GTiff',
'dtype': 'uint8',
'width': SIZE,
'height': SIZE,
'crs': CRS.from_epsg(4326),
# NOTE(review): presumably the usual (a, b, c, d, e, f) affine coefficients,
# i.e. pixel width, 0, upper-left x, 0, negative pixel height, upper-left y.
'transform': Affine(
4.489235388119662e-06,
0.0,
-43.7732462563,
0.0,
-4.486127586210932e-06,
-22.9214851954,
),
}

# Fixed seed so the pixel values are the same on every run of this script.
np.random.seed(0)
# One SIZE x SIZE uint8 array, reused for every band of every image. With a
# single bound, randint's upper limit is exclusive, so values fall in [0, 254].
Z = np.random.randint(np.iinfo('uint8').max, size=(SIZE, SIZE), dtype='uint8')


def create_directories(base_path: str, band_counts: list[int]) -> None:
    """Create one ``{count}band`` directory under *base_path* per band count.

    Missing parent directories are created as needed, and directories that
    already exist are left untouched.

    Args:
        base_path: Directory that will contain the per-band subdirectories.
        band_counts: Band counts to create a subdirectory for, e.g. ``[3, 8]``.
    """
    band_dirs = [os.path.join(base_path, f'{count}band') for count in band_counts]
    for band_dir in band_dirs:
        os.makedirs(band_dir, exist_ok=True)


def generate_geotiff_files(
    base_path: str, band_counts: list[int], profile: dict[str, Any], Z: np.ndarray
) -> None:
    """Write ``NUM_SAMPLES`` dummy GeoTIFFs into each ``{count}band`` directory.

    Files are named ``{count}band_AOI_1_RIO_img{i}.tif`` with ``i`` starting
    at 1. The ``{count}band`` directories are expected to exist already (the
    script calls ``create_directories`` first).

    Args:
        base_path: Split directory holding the ``{count}band`` subdirectories.
        band_counts: Number of bands per image; one directory per entry.
        profile: Keyword arguments forwarded to ``rasterio.open`` in write
            mode. It is not modified; the band count is added to a copy.
        Z: Array written unchanged to every band of every image.
    """
    for count in band_counts:
        # Build a per-band copy instead of assigning profile['count'], so the
        # caller's dict does not silently keep the last band count written.
        band_profile = {**profile, 'count': count}
        for i in range(1, NUM_SAMPLES + 1):
            path = os.path.join(
                base_path, f'{count}band', f'{count}band_AOI_1_RIO_img{i}.tif'
            )
            with rasterio.open(path, 'w', **band_profile) as dst:
                # Raster band indexes are 1-based.
                for band in range(1, count + 1):
                    dst.write(Z, band)


def generate_geojson_files(base_path: str, geojson: dict[str, Any]) -> None:
    """Write ``NUM_SAMPLES`` label files into ``<base_path>/geojson``.

    Files are named ``Geo_AOI_1_RIO_img{i}.geojson`` with ``i`` starting at 1.
    Even-numbered files receive *geojson* serialized as JSON; odd-numbered
    files are created but left empty.

    Args:
        base_path: Split directory in which the ``geojson`` directory is made.
        geojson: JSON-serializable mapping dumped into the even-numbered files.
    """
    label_dir = os.path.join(base_path, 'geojson')
    os.makedirs(label_dir, exist_ok=True)
    for i in range(1, NUM_SAMPLES + 1):
        is_even = i % 2 == 0
        label_path = os.path.join(label_dir, f'Geo_AOI_1_RIO_img{i}.geojson')
        # Opening in 'w' mode creates (or truncates) the file even when
        # nothing is written to it.
        with open(label_path, 'w') as sink:
            if is_even:
                json.dump(geojson, sink)


def compute_md5(file_path: str) -> str:
    """Return the hexadecimal MD5 digest of the file at *file_path*.

    The file is read in 4096-byte chunks so it never has to fit in memory.

    Args:
        file_path: Path of the file to hash.

    Returns:
        The 32-character lowercase hex digest of the file contents.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # The digest is only a fixture checksum, not a security control. Declaring
    # that keeps MD5 usable on builds that block insecure hashes (e.g. FIPS).
    hash_md5 = hashlib.md5(usedforsecurity=False)
    with open(file_path, 'rb') as f:
        # read() returns b'' at end of file, which ends the loop.
        while chunk := f.read(4096):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


# Dummy building-footprint label: a FeatureCollection holding one Polygon
# feature. generate_geojson_files dumps this dict into the even-numbered label
# files. The ring is closed: its first and last coordinates are identical.
geojson = {
'type': 'FeatureCollection',
'crs': {'type': 'name', 'properties': {'name': 'urn:ogc:def:crs:OGC:1.3:CRS84'}},
'features': [
{
'type': 'Feature',
'geometry': {
'type': 'Polygon',
# NOTE(review): presumably (longitude, latitude, elevation) triples.
'coordinates': [
[
[-43.7720361, -22.922229499999958, 0.0],
[-43.772064, -22.9222724, 0.0],
[-43.772102399999937, -22.922247399999947, 0.0],
[-43.772074499999974, -22.9222046, 0.0],
[-43.7720361, -22.922229499999958, 0.0],
]
],
},
}
],
}

# Start from a clean output directory so stale files never end up in the archives
if os.path.exists(dataset_id):
    shutil.rmtree(dataset_id)

train_base_path = os.path.join(dataset_id, 'train')
test_base_path = os.path.join(dataset_id, 'test')

# Both splits get dummy 3-band and 8-band GeoTIFF imagery
for split_base_path in (train_base_path, test_base_path):
    create_directories(split_base_path, [3, 8])
    generate_geotiff_files(split_base_path, [3, 8], profile, Z)

# Only the train split gets GeoJSON labels; the test split is imagery only
generate_geojson_files(train_base_path, geojson)

# Archives to build, as split -> {subdirectory to pack: archive base name}.
# The base names carry no extension; make_archive adds it for the chosen format.
tarball_specs = {
    'train': {
        '3band': 'SN1_buildings_train_AOI_1_Rio_3band',
        '8band': 'SN1_buildings_train_AOI_1_Rio_8band',
        'geojson': 'SN1_buildings_train_AOI_1_Rio_geojson_buildings',
    },
    'test': {
        '3band': 'SN1_buildings_test_AOI_1_Rio_3band',
        '8band': 'SN1_buildings_test_AOI_1_Rio_8band',
    },
}

# Pack each subdirectory into a gzip-compressed tarball placed in its split directory
for split, specs in tarball_specs.items():
    split_dir = os.path.join(dataset_id, split)
    for subdir, tarball_name in specs.items():
        # root_dir/base_dir make the archive members relative to the split directory
        shutil.make_archive(
            os.path.join(split_dir, tarball_name),
            'gztar',
            root_dir=split_dir,
            base_dir=subdir,
        )

# Compute and print MD5 checksums for the generated tarballs.
# The file names are derived from tarball_specs rather than listed a second
# time, so this report cannot drift from the set of archives that were built.
# The 'gztar' format used for the archives yields a '.tar.gz' extension.
train_tarballs = [f'{name}.tar.gz' for name in tarball_specs['train'].values()]
test_tarballs = [f'{name}.tar.gz' for name in tarball_specs['test'].values()]

for split, heading, tarballs in (
    ('train', 'MD5 Checksums for Train Dataset:', train_tarballs),
    ('test', '\nMD5 Checksums for Test Dataset:', test_tarballs),
):
    print(heading)
    for tarball in tarballs:
        tarball_path = os.path.join(dataset_id, split, tarball)
        # Best effort: an archive that is missing is skipped, not an error.
        if os.path.exists(tarball_path):
            print(f'{tarball}: {compute_md5(tarball_path)}')
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type": "FeatureCollection", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "geometry": {"type": "Polygon", "coordinates": [[[4.47917, 51.9225, 0.0], [4.4792, 51.92255, 0.0], [4.47925, 51.92252, 0.0], [4.47922, 51.92247, 0.0], [4.47917, 51.9225, 0.0]]]}}]}
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type": "FeatureCollection", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "geometry": {"type": "Polygon", "coordinates": [[[4.47917, 51.9225, 0.0], [4.4792, 51.92255, 0.0], [4.47925, 51.92252, 0.0], [4.47922, 51.92247, 0.0], [4.47917, 51.9225, 0.0]]]}}]}
Loading

0 comments on commit c746cac

Please sign in to comment.