Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add VHR10DataModule #798

Closed
wants to merge 19 commits into from
Closed
28 changes: 28 additions & 0 deletions conf/vhr10.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
program:
seed: 0
overwrite: True

trainer:
gpus: 1
min_epochs: 5
max_epochs: 100
auto_lr_find: False
benchmark: True

experiment:
task: "vhr10"
name: "vhr10_test"
module:
detection_model: "faster-rcnn"
backbone: "resnet50"
pretrained: True
num_classes: 11
learning_rate: 1.3e-5
learning_rate_schedule_patience: 6
verbose: false
datamodule:
root: "data/vhr10"
batch_size: 2
patch_size: 512
num_workers: 56
val_split_pct: 0.2
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies:
- pytorch>=1.9
- rarfile>=3
- rasterio>=1.0.20
- scikit-image>=0.15.0
- shapely>=1.3
- torchvision>=0.10
- pip:
Expand Down
4 changes: 3 additions & 1 deletion requirements/min.old
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pyproj==2.2.0
pytorch-lightning==1.5.1
rasterio==1.0.20
rtree==1.0.0
scikit-image==0.15.0
scikit-learn==0.21.0
segmentation-models-pytorch==0.2.0
shapely==1.3.0
Expand All @@ -28,10 +29,11 @@ laspy==2.0.0
open3d==0.11.2
opencv-python==3.4.2.17
pandas==0.23.2
pycocotools==2.0.0
pycocotools==2.0.1
radiant-mlhub==0.2.1
rarfile==3.0
scipy==1.2.0
scikit-image==0.15.0
zipfile-deflate64==0.2.0

# docs
Expand Down
1 change: 1 addition & 0 deletions requirements/required.old
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pytorch-lightning==1.6.4
rasterio==1.3.0;python_version>='3.8'
rasterio==1.2.10;python_version=='3.7'
rtree==1.0.0
scikit-image>=0.15.0
scikit-learn==1.1.1;python_version>='3.8'
scikit-learn==1.0.2;python_version=='3.7'
segmentation-models-pytorch==0.2.1
Expand Down
1 change: 1 addition & 0 deletions requirements/required.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pyproj==3.4.0;python_version>='3.8'
pytorch-lightning==1.7.7
rasterio==1.3.2;python_version>='3.8'
rtree==1.0.1
scikit-image>=0.15.0
scikit-learn==1.1.2;python_version>='3.8'
segmentation-models-pytorch==0.3.0
shapely==1.8.5.post1
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ install_requires =
rtree>=1,<2
# scikit-learn 0.21+ required to fix murmurhash3_32 import bug
scikit-learn>=0.21,<2
# scikit-image required for find_contours
scikit-image>=0.15.0,<0.20
# segmentation-models-pytorch 0.2+ required for smp.losses module
segmentation-models-pytorch>=0.2,<0.4
# shapely 1.3+ required for Python 3 support
Expand Down
16 changes: 16 additions & 0 deletions tests/conf/vhr10.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

experiment:
task: "vhr10"
module:
detection_model: "faster-rcnn"
backbone: "resnet18"
num_classes: 11
learning_rate: 1e-4
learning_rate_schedule_patience: 6
verbose: false
datamodule:
root: "tests/data/vhr10"
seed: 0
batch_size: 1
num_workers: 0
patch_size: 4
Binary file modified tests/data/vhr10/NWPU VHR-10 dataset.rar
Binary file not shown.
135 changes: 1 addition & 134 deletions tests/data/vhr10/annotations.json
Original file line number Diff line number Diff line change
@@ -1,134 +1 @@
{
"info": {
"description": null,
"url": null,
"version": null,
"year": 2021,
"contributor": null,
"date_created": "2021-01-01 00:00:00"
},
"licenses": [
{
"url": null,
"id": 0,
"name": null
}
],
"images": [
{
"license": 0,
"url": null,
"file_name": "001.jpg",
"height": 1,
"width": 1,
"date_captured": null,
"id": 0
},
{
"license": 0,
"url": null,
"file_name": "002.jpg",
"height": 1,
"width": 1,
"date_captured": null,
"id": 1
}
],
"type": "instances",
"annotations": [
{
"id": 0,
"image_id": 0,
"category_id": 1,
"segmentation": [
[
1,
2,
3,
4
]
],
"area": 1.0,
"bbox": [
1,
2,
3,
4
],
"iscrowd": 0
},
{
"id": 1,
"image_id": 1,
"category_id": 1,
"segmentation": [
[
1,
2,
3,
4
]
],
"area": 1.0,
"bbox": [
1,
2,
3,
4
],
"iscrowd": 0
}
],
"categories": [
{
"supercategory": null,
"id": 1,
"name": "airplane"
},
{
"supercategory": null,
"id": 2,
"name": "ship"
},
{
"supercategory": null,
"id": 3,
"name": "storage_tank"
},
{
"supercategory": null,
"id": 4,
"name": "baseball_diamond"
},
{
"supercategory": null,
"id": 5,
"name": "tennis_court"
},
{
"supercategory": null,
"id": 6,
"name": "basketball_court"
},
{
"supercategory": null,
"id": 7,
"name": "ground_track_field"
},
{
"supercategory": null,
"id": 8,
"name": "harbor"
},
{
"supercategory": null,
"id": 9,
"name": "bridge"
},
{
"supercategory": null,
"id": 10,
"name": "vehicle"
}
]
}
{"images": [{"file_name": "001.jpg", "height": 8, "width": 8, "id": 0}, {"file_name": "002.jpg", "height": 8, "width": 8, "id": 1}, {"file_name": "003.jpg", "height": 8, "width": 8, "id": 2}, {"file_name": "004.jpg", "height": 8, "width": 8, "id": 3}, {"file_name": "005.jpg", "height": 8, "width": 8, "id": 4}], "annotations": [{"id": 0, "image_id": 0, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "iscrowd": 0}, {"id": 1, "image_id": 1, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 2, "image_id": 2, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 3, "image_id": 3, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 4, "image_id": 4, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}]}
104 changes: 104 additions & 0 deletions tests/data/vhr10/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import json
import os
import shutil
import subprocess
import warnings
from copy import deepcopy

import numpy as np
import rasterio as rio
from rasterio.errors import NotGeoreferencedWarning
from torchvision.datasets.utils import calculate_md5

# Empty skeleton of the annotation JSON produced by generate_test_data:
# one list of image records and one list of per-image annotation records.
ANNOTATION_FILE = {"images": [], "annotations": []}


def write_data(path: str, img: np.ndarray) -> None:
    """Write ``img`` to ``path`` as a three-band raster.

    The file is created with the ``JP2OpenJPEG`` driver and the same array
    is written into every band.

    Args:
        path: destination file path
        img: image array; ``img.shape[0]`` and ``img.shape[1]`` are used as
            the raster height and width, ``img.dtype`` as the raster dtype
    """
    profile = {
        "driver": "JP2OpenJPEG",
        "height": img.shape[0],
        "width": img.shape[1],
        "count": 3,
        "dtype": img.dtype,
    }
    with warnings.catch_warnings():
        # Suppress rasterio's NotGeoreferencedWarning for the duration of the
        # write only; the filter is restored when the context exits.
        warnings.simplefilter("ignore", category=NotGeoreferencedWarning)
        with rio.open(path, "w", **profile) as dst:
            # Band indices are 1-based.
            for offset in range(dst.count):
                dst.write(img, offset + 1)


def generate_test_data(root: str, n_imgs: int = 3) -> str:
    """Create a miniature fake NWPU VHR-10 dataset under ``root``.

    The function writes ``n_imgs`` random 8x8 images into both the
    "positive image set" and "negative image set" folders, writes a
    matching ``annotations.json`` (inside the dataset folder and next to it
    in ``root``), packs the dataset folder into ``NWPU VHR-10 dataset.rar``
    inside ``root`` and finally removes the unpacked folder.

    The first annotation is written without a ``segmentation`` entry; all
    others include one.

    Args:
        root: directory in which the dataset folder, the archive and the
            top-level ``annotations.json`` are created
        n_imgs: number of positive/negative image pairs to generate

    Returns:
        a string reporting the md5 checksums of the archive and of the
        annotation file

    Raises:
        subprocess.CalledProcessError: if the ``rar`` command exits with a
            non-zero status
        FileNotFoundError: if the ``rar`` executable cannot be found
    """
    folder_name = "NWPU VHR-10 dataset"
    archive_name = f"{folder_name}.rar"

    folder_path = os.path.join(root, folder_name)
    pos_img_dir = os.path.join(folder_path, "positive image set")
    neg_img_dir = os.path.join(folder_path, "negative image set")
    ann_file = os.path.join(folder_path, "annotations.json")
    ann_file2 = os.path.join(root, "annotations.json")

    os.makedirs(pos_img_dir, exist_ok=True)
    os.makedirs(neg_img_dir, exist_ok=True)

    # Work on a private copy of the template so that calling this function
    # more than once does not accumulate entries in the module-level dict.
    ann_data = deepcopy(ANNOTATION_FILE)

    for img_id in range(1, n_imgs + 1):
        # Zero-pad to three digits so ids >= 10 keep the "NNN.jpg" pattern.
        img_name = f"{img_id:03d}.jpg"

        img = np.random.randint(255, size=(8, 8), dtype=np.dtype("uint8"))
        write_data(os.path.join(pos_img_dir, img_name), img)
        write_data(os.path.join(neg_img_dir, img_name), img)

        ann_data["images"].append(
            {"file_name": img_name, "height": 8, "width": 8, "id": img_id - 1}
        )

    for ann_id, img_info in enumerate(ann_data["images"]):
        annot = {
            "id": ann_id,
            "image_id": img_info["id"],
            "category_id": 1,
            "area": 4.0,
            "bbox": [4, 4, 2, 2],
            "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]],
            "iscrowd": 0,
        }
        if ann_id == 0:
            # Keep one annotation without a segmentation so the fixture
            # contains both kinds of record.
            del annot["segmentation"]
        ann_data["annotations"].append(annot)

    for target in (ann_file, ann_file2):
        with open(target, "w") as j:
            json.dump(ann_data, j)

    # Create rar file. Run inside ``root`` so the archive stores the folder
    # under its bare name and lands next to it, whatever the caller's
    # current working directory is.
    subprocess.run(
        ["rar", "a", archive_name, "-m5", folder_name],
        capture_output=True,
        check=True,
        cwd=root,
    )

    annotations_md5 = calculate_md5(ann_file)
    archive_md5 = calculate_md5(os.path.join(root, archive_name))
    shutil.rmtree(folder_path)

    return f"archive md5: {archive_md5}, annotation md5: {annotations_md5}"


if __name__ == "__main__":
    # Script entry point: generate five image pairs in the current working
    # directory and print the resulting checksums.
    print(generate_test_data(os.getcwd(), 5))
37 changes: 37 additions & 0 deletions tests/datamodules/test_vhr10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import pytest

from torchgeo.datamodules import VHR10DataModule


class TestVHR10DataModule:
    """Smoke tests: each dataloader of VHR10DataModule yields a batch."""

    @pytest.fixture(scope="class")
    def datamodule(self) -> VHR10DataModule:
        """Return a datamodule over tests/data/vhr10, prepared and set up."""
        module = VHR10DataModule(
            root=os.path.join("tests", "data", "vhr10"),
            batch_size=1,
            num_workers=0,
            val_split_pct=0.4,
            test_split_pct=0.2,
        )
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: VHR10DataModule) -> None:
        """Fetch one batch from the training dataloader."""
        batches = iter(datamodule.train_dataloader())
        next(batches)

    def test_val_dataloader(self, datamodule: VHR10DataModule) -> None:
        """Fetch one batch from the validation dataloader."""
        batches = iter(datamodule.val_dataloader())
        next(batches)

    def test_test_dataloader(self, datamodule: VHR10DataModule) -> None:
        """Fetch one batch from the test dataloader."""
        batches = iter(datamodule.test_dataloader())
        next(batches)
Loading