diff --git a/conf/vhr10.yaml b/conf/vhr10.yaml
new file mode 100644
index 00000000000..71a7cbaa43b
--- /dev/null
+++ b/conf/vhr10.yaml
@@ -0,0 +1,28 @@
+program:
+  seed: 0
+  overwrite: True
+
+trainer:
+  gpus: 1
+  min_epochs: 5
+  max_epochs: 100
+  auto_lr_find: False
+  benchmark: True
+
+experiment:
+  task: "vhr10"
+  name: "vhr10_test"
+  module:
+    detection_model: "faster-rcnn"
+    backbone: "resnet50"
+    pretrained: True
+    num_classes: 11
+    learning_rate: 1.3e-5
+    learning_rate_schedule_patience: 6
+    verbose: false
+  datamodule:
+    root: "data/vhr10"
+    batch_size: 2
+    patch_size: 512
+    num_workers: 56
+    val_split_pct: 0.2
diff --git a/environment.yml b/environment.yml
index dafd0778a56..c82849aaa35 100644
--- a/environment.yml
+++ b/environment.yml
@@ -16,6 +16,7 @@ dependencies:
   - pytorch>=1.9
   - rarfile>=3
   - rasterio>=1.0.20
+  - scikit-image>=0.15.0
   - shapely>=1.3
   - torchvision>=0.10
   - pip:
diff --git a/requirements/min.old b/requirements/min.old
index 54d85ac290a..cad884bfe9a 100644
--- a/requirements/min.old
+++ b/requirements/min.old
@@ -14,6 +14,7 @@ pyproj==2.2.0
 pytorch-lightning==1.5.1
 rasterio==1.0.20
 rtree==1.0.0
+scikit-image==0.15.0
 scikit-learn==0.21.0
 segmentation-models-pytorch==0.2.0
 shapely==1.3.0
@@ -28,10 +29,10 @@ laspy==2.0.0
 open3d==0.11.2
 opencv-python==3.4.2.17
 pandas==0.23.2
-pycocotools==2.0.0
+pycocotools==2.0.1
 radiant-mlhub==0.2.1
 rarfile==3.0
 scipy==1.2.0
 zipfile-deflate64==0.2.0

# docs
diff --git a/requirements/required.old b/requirements/required.old
index 1e120e0948c..1bb2b38d75d 100644
--- a/requirements/required.old
+++ b/requirements/required.old
@@ -17,6 +17,7 @@ pytorch-lightning==1.6.4
 rasterio==1.3.0;python_version>='3.8'
 rasterio==1.2.10;python_version=='3.7'
 rtree==1.0.0
+scikit-image>=0.15.0
 scikit-learn==1.1.1;python_version>='3.8'
 scikit-learn==1.0.2;python_version=='3.7'
 segmentation-models-pytorch==0.2.1
diff --git a/requirements/required.txt b/requirements/required.txt
index 0400203b78a..604eb667c6c 100644
--- a/requirements/required.txt
+++ b/requirements/required.txt
@@ -14,6 +14,7 @@ pyproj==3.4.0;python_version>='3.8'
 pytorch-lightning==1.7.7
 rasterio==1.3.2;python_version>='3.8'
 rtree==1.0.1
+scikit-image>=0.15.0
 scikit-learn==1.1.2;python_version>='3.8'
 segmentation-models-pytorch==0.3.0
 shapely==1.8.5.post1
diff --git a/setup.cfg b/setup.cfg
index 58559d5f52e..50a237da810 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -51,6 +51,8 @@ install_requires =
     rtree>=1,<2
+    # scikit-image 0.15+ required for find_contours
+    scikit-image>=0.15,<0.20
     # scikit-learn 0.21+ required to fix murmurhash3_32 import bug
     scikit-learn>=0.21,<2
     # segmentation-models-pytorch 0.2+ required for smp.losses module
     segmentation-models-pytorch>=0.2,<0.4
     # shapely 1.3+ required for Python 3 support
diff --git a/tests/conf/vhr10.yaml b/tests/conf/vhr10.yaml
new file mode 100644
index 00000000000..ad432e047b9
--- /dev/null
+++ b/tests/conf/vhr10.yaml
@@ -0,0 +1,16 @@
+
+experiment:
+  task: "vhr10"
+  module:
+    detection_model: "faster-rcnn"
+    backbone: "resnet18"
+    num_classes: 11
+    learning_rate: 1e-4
+    learning_rate_schedule_patience: 6
+    verbose: false
+  datamodule:
+    root: "tests/data/vhr10"
+    seed: 0
+    batch_size: 1
+    num_workers: 0
+    patch_size: 4
diff --git a/tests/data/vhr10/NWPU VHR-10 dataset.rar b/tests/data/vhr10/NWPU VHR-10 dataset.rar
index 5fc4953cef1..3dc735dd4b6 100644
Binary files a/tests/data/vhr10/NWPU VHR-10 dataset.rar and b/tests/data/vhr10/NWPU VHR-10 dataset.rar differ
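The setup.cfg comment above ties the new dependency to skimage.measure.find_contours, which the rewritten VHR10.plot further down uses to trace instance masks into polygon outlines. A minimal sketch of that pattern on a toy mask (all values illustrative):

    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib import patches
    from skimage.measure import find_contours

    mask = np.zeros((8, 8), dtype=np.uint8)
    mask[2:6, 3:7] = 1  # toy instance mask

    fig, ax = plt.subplots()
    ax.imshow(mask, cmap="gray")
    for verts in find_contours(mask, 0.5):  # vertices along the 0.5 iso-level
        verts = np.fliplr(verts)            # (row, col) -> (x, y)
        ax.add_patch(patches.Polygon(verts, facecolor="red", alpha=0.7, edgecolor="white"))

diff --git a/tests/data/vhr10/annotations.json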
b/tests/data/vhr10/annotations.json index 406ddc21a9b..60de0cb14f3 100644 --- a/tests/data/vhr10/annotations.json +++ b/tests/data/vhr10/annotations.json @@ -1,134 +1 @@ -{ - "info": { - "description": null, - "url": null, - "version": null, - "year": 2021, - "contributor": null, - "date_created": "2021-01-01 00:00:00" - }, - "licenses": [ - { - "url": null, - "id": 0, - "name": null - } - ], - "images": [ - { - "license": 0, - "url": null, - "file_name": "001.jpg", - "height": 1, - "width": 1, - "date_captured": null, - "id": 0 - }, - { - "license": 0, - "url": null, - "file_name": "002.jpg", - "height": 1, - "width": 1, - "date_captured": null, - "id": 1 - } - ], - "type": "instances", - "annotations": [ - { - "id": 0, - "image_id": 0, - "category_id": 1, - "segmentation": [ - [ - 1, - 2, - 3, - 4 - ] - ], - "area": 1.0, - "bbox": [ - 1, - 2, - 3, - 4 - ], - "iscrowd": 0 - }, - { - "id": 1, - "image_id": 1, - "category_id": 1, - "segmentation": [ - [ - 1, - 2, - 3, - 4 - ] - ], - "area": 1.0, - "bbox": [ - 1, - 2, - 3, - 4 - ], - "iscrowd": 0 - } - ], - "categories": [ - { - "supercategory": null, - "id": 1, - "name": "airplane" - }, - { - "supercategory": null, - "id": 2, - "name": "ship" - }, - { - "supercategory": null, - "id": 3, - "name": "storage_tank" - }, - { - "supercategory": null, - "id": 4, - "name": "baseball_diamond" - }, - { - "supercategory": null, - "id": 5, - "name": "tennis_court" - }, - { - "supercategory": null, - "id": 6, - "name": "basketball_court" - }, - { - "supercategory": null, - "id": 7, - "name": "ground_track_field" - }, - { - "supercategory": null, - "id": 8, - "name": "harbor" - }, - { - "supercategory": null, - "id": 9, - "name": "bridge" - }, - { - "supercategory": null, - "id": 10, - "name": "vehicle" - } - ] -} +{"images": [{"file_name": "001.jpg", "height": 8, "width": 8, "id": 0}, {"file_name": "002.jpg", "height": 8, "width": 8, "id": 1}, {"file_name": "003.jpg", "height": 8, "width": 8, "id": 2}, {"file_name": "004.jpg", "height": 8, "width": 8, "id": 3}, {"file_name": "005.jpg", "height": 8, "width": 8, "id": 4}], "annotations": [{"id": 0, "image_id": 0, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "iscrowd": 0}, {"id": 1, "image_id": 1, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 2, "image_id": 2, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 3, "image_id": 3, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 4, "image_id": 4, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}]} \ No newline at end of file diff --git a/tests/data/vhr10/data.py b/tests/data/vhr10/data.py new file mode 100644 index 00000000000..3be22226954 --- /dev/null +++ b/tests/data/vhr10/data.py @@ -0,0 +1,104 @@ +import json +import os +import shutil +import subprocess +import warnings +from copy import deepcopy + +import numpy as np +import rasterio as rio +from rasterio.errors import NotGeoreferencedWarning +from torchvision.datasets.utils import calculate_md5 + +ANNOTATION_FILE = {"images": [], "annotations": []} + + +def write_data(path: str, img: np.ndarray) -> None: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=NotGeoreferencedWarning) + with rio.open( + path, + "w", + driver="JP2OpenJPEG", + height=img.shape[0], + width=img.shape[1], + count=3, + 
dtype=img.dtype,
+        ) as dst:
+            for i in range(1, dst.count + 1):
+                dst.write(img, i)
+
+
+def generate_test_data(root: str, n_imgs: int = 3) -> str:
+    folder_path = os.path.join(root, "NWPU VHR-10 dataset")
+    pos_img_dir = os.path.join(folder_path, "positive image set")
+    neg_img_dir = os.path.join(folder_path, "negative image set")
+    ann_file = os.path.join(folder_path, "annotations.json")
+    ann_file2 = os.path.join(root, "annotations.json")
+
+    if not os.path.exists(pos_img_dir):
+        os.makedirs(pos_img_dir)
+    if not os.path.exists(neg_img_dir):
+        os.makedirs(neg_img_dir)
+
+    for img_id in range(1, n_imgs + 1):
+        # Zero-pad to match the f"{id_:03d}.jpg" naming the dataset expects
+        pos_img_name = os.path.join(pos_img_dir, f"{img_id:03d}.jpg")
+        neg_img_name = os.path.join(neg_img_dir, f"{img_id:03d}.jpg")
+
+        img = np.random.randint(255, size=(8, 8), dtype=np.dtype("uint8"))
+        write_data(pos_img_name, img)
+        write_data(neg_img_name, img)
+
+        img_name = os.path.basename(pos_img_name)
+
+        ANNOTATION_FILE["images"].append(
+            {"file_name": img_name, "height": 8, "width": 8, "id": img_id - 1}
+        )
+
+    ann = 0
+    for i, img in enumerate(ANNOTATION_FILE["images"]):
+        annot = {
+            "id": ann,
+            "image_id": img["id"],
+            "category_id": 1,
+            "area": 4.0,
+            "bbox": [4, 4, 2, 2],
+            "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]],
+            "iscrowd": 0,
+        }
+        if i != 0:
+            ANNOTATION_FILE["annotations"].append(annot)
+        else:
+            # First image keeps no segmentation to exercise the mask-free path
+            noseg_annot = deepcopy(annot)
+            del noseg_annot["segmentation"]
+            ANNOTATION_FILE["annotations"].append(noseg_annot)
+        ann += 1
+
+    with open(ann_file, "w") as j:
+        json.dump(ANNOTATION_FILE, j)
+
+    with open(ann_file2, "w") as j:
+        json.dump(ANNOTATION_FILE, j)
+
+    # Create rar file
+    subprocess.run(
+        ["rar", "a", "NWPU VHR-10 dataset.rar", "-m5", "NWPU VHR-10 dataset"],
+        capture_output=True,
+        check=True,
+    )
+
+    annotations_md5 = calculate_md5(ann_file)
+    archive_md5 = calculate_md5("NWPU VHR-10 dataset.rar")
+    shutil.rmtree(folder_path)
+
+    return f"archive md5: {archive_md5}, annotation md5: {annotations_md5}"
+
+
+if __name__ == "__main__":
+    md5 = generate_test_data(os.getcwd(), 5)
+    print(md5)
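Running data.py regenerates the fixture archive and prints the checksums that tests/datasets/test_vhr10.py pins. They can also be recomputed independently with the same torchvision helper the script imports; a sketch, assuming the repository root as the working directory:

    import os

    from torchvision.datasets.utils import calculate_md5

    root = os.path.join("tests", "data", "vhr10")
    print("archive md5:", calculate_md5(os.path.join(root, "NWPU VHR-10 dataset.rar")))
    print("annotation md5:", calculate_md5(os.path.join(root, "annotations.json")))

diff --git a/tests/datamodules/test_vhr10.py b/tests/datamodules/test_vhr10.py
new file mode 100644
index 00000000000..3fbb9a6bd29
--- /dev/null
+++ b/tests/datamodules/test_vhr10.py
@@ -0,0 +1,37 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.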
+ +import os + +import pytest + +from torchgeo.datamodules import VHR10DataModule + + +class TestVHR10DataModule: + @pytest.fixture(scope="class") + def datamodule(self) -> VHR10DataModule: + root = os.path.join("tests", "data", "vhr10") + batch_size = 1 + num_workers = 0 + val_split_pct = 0.4 + test_split_pct = 0.2 + dm = VHR10DataModule( + root=root, + batch_size=batch_size, + num_workers=num_workers, + val_split_pct=val_split_pct, + test_split_pct=test_split_pct, + ) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: VHR10DataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: VHR10DataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: VHR10DataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/datasets/test_nwpu.py b/tests/datasets/test_vhr10.py similarity index 68% rename from tests/datasets/test_nwpu.py rename to tests/datasets/test_vhr10.py index 582b6730908..d2400ee706c 100644 --- a/tests/datasets/test_nwpu.py +++ b/tests/datasets/test_vhr10.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import Any +import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn @@ -30,15 +31,15 @@ def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> VHR10: pytest.importorskip("rarfile", minversion="3") - monkeypatch.setattr(torchgeo.datasets.nwpu, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.vhr10, "download_url", download_url) monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) url = os.path.join("tests", "data", "vhr10", "NWPU VHR-10 dataset.rar") monkeypatch.setitem(VHR10.image_meta, "url", url) - md5 = "e5c38351bd948479fe35a71136aedbc4" + md5 = "e68727b2c91ac849def1985924d9586f" monkeypatch.setitem(VHR10.image_meta, "md5", md5) url = os.path.join("tests", "data", "vhr10", "annotations.json") monkeypatch.setitem(VHR10.target_meta, "url", url) - md5 = "16fc6aa597a19179dad84151cc221873" + md5 = "833899cce369168e0d4ee420dac326dc" monkeypatch.setitem(VHR10.target_meta, "md5", md5) root = str(tmp_path) split = request.param @@ -57,14 +58,19 @@ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: monkeypatch.setattr(builtins, "__import__", mocked_import) def test_getitem(self, dataset: VHR10) -> None: - x = dataset[0] - assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["label"], dict) + for i in range(2): + x = dataset[i] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + if dataset.split == "positive": + assert isinstance(x["labels"], torch.Tensor) + assert isinstance(x["boxes"], torch.Tensor) + if "masks" in x: + assert isinstance(x["masks"], torch.Tensor) def test_len(self, dataset: VHR10) -> None: if dataset.split == "positive": - assert len(dataset) == 650 + assert len(dataset) == len(dataset.ids) elif dataset.split == "negative": assert len(dataset) == 150 @@ -72,7 +78,7 @@ def test_add(self, dataset: VHR10) -> None: ds = dataset + dataset assert isinstance(ds, ConcatDataset) if dataset.split == "positive": - assert len(ds) == 1300 + assert len(ds) == 10 elif dataset.split == "negative": assert len(ds) == 300 @@ -96,3 +102,21 @@ def test_mock_missing_module( match="pycocotools is not installed and is required to use this datase", ): VHR10(dataset.root, dataset.split) + + def test_plot(self, dataset: VHR10) -> None: + x = 
dataset[1].copy()
+        dataset.plot(x, suptitle="Test")
+        plt.close()
+        dataset.plot(x, show_titles=False)
+        plt.close()
+        if dataset.split == "positive":
+            scores = [0.7, 0.3, 0.7]
+            for i in range(3):
+                x = dataset[i]
+                x["prediction_labels"] = x["labels"]
+                x["prediction_boxes"] = x["boxes"]
+                x["prediction_scores"] = torch.Tensor([scores[i]])
+                if "masks" in x:
+                    x["prediction_masks"] = x["masks"]
+                dataset.plot(x, show_feats="masks")
+                plt.close()
diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py
index 0f2d8a65cc6..f2d9e78eb22 100644
--- a/tests/trainers/test_detection.py
+++ b/tests/trainers/test_detection.py
@@ -9,13 +9,17 @@
 from omegaconf import OmegaConf
 from pytorch_lightning import LightningDataModule, Trainer
 
-from torchgeo.datamodules import NASAMarineDebrisDataModule
+from torchgeo.datamodules import NASAMarineDebrisDataModule, VHR10DataModule
 from torchgeo.trainers import ObjectDetectionTask
 
 
 class TestObjectDetectionTask:
     @pytest.mark.parametrize(
-        "name,classname", [("nasa_marine_debris", NASAMarineDebrisDataModule)]
+        "name,classname",
+        [
+            ("nasa_marine_debris", NASAMarineDebrisDataModule),
+            ("vhr10", VHR10DataModule),
+        ],
     )
     def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
         conf = OmegaConf.load(os.path.join("tests", "conf", f"{name}.yaml"))
diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py
index 2230cf00eaf..72e44adf836 100644
--- a/torchgeo/datamodules/__init__.py
+++ b/torchgeo/datamodules/__init__.py
@@ -24,6 +24,7 @@
 from .ucmerced import UCMercedDataModule
 from .usavars import USAVarsDataModule
 from .vaihingen import Vaihingen2DDataModule
+from .vhr10 import VHR10DataModule
 from .xview import XView2DataModule
 
 __all__ = (
@@ -50,6 +51,7 @@
     "UCMercedDataModule",
     "USAVarsDataModule",
     "Vaihingen2DDataModule",
+    "VHR10DataModule",
     "XView2DataModule",
 )
diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py
new file mode 100644
index 00000000000..2e0e3eb35c4
--- /dev/null
+++ b/torchgeo/datamodules/vhr10.py
@@ -0,0 +1,170 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""NWPU VHR-10 datamodule."""
+
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
+
+import matplotlib.pyplot as plt
+import pytorch_lightning as pl
+import torch
+import torchvision
+from torch import Tensor
+from torch.utils.data import DataLoader
+
+from ..datasets import VHR10
+from ..samplers.utils import _to_tuple
+from .utils import dataset_split
+
+# https://github.com/pytorch/pytorch/issues/60979
+# https://github.com/pytorch/pytorch/pull/61045
+DataLoader.__module__ = "torch.utils.data"
+
+
+def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
+    """Custom object detection collate fn to handle variable numbers of boxes.
+
+    Args:
+        batch: list of sample dicts returned by the dataset
+
+    Returns:
+        batch dict with images stacked and per-image targets kept as lists
+    """
+    output: Dict[str, Any] = {}
+    output["image"] = torch.stack([sample["image"] for sample in batch])
+    output["labels"] = [sample["labels"] for sample in batch]
+    output["boxes"] = [sample["boxes"] for sample in batch]
+    if "masks" in batch[0]:
+        output["masks"] = [sample["masks"] for sample in batch]
+    return output
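Default PyTorch collation would try to stack per-image targets of unequal length, so the collate_fn above stacks only the images and keeps labels, boxes, and masks as lists. A toy sketch of the resulting batch structure (shapes illustrative; assumes the collate_fn defined above is in scope):

    import torch

    batch = [
        {"image": torch.zeros(3, 8, 8), "labels": torch.tensor([1]), "boxes": torch.zeros(1, 4)},
        {"image": torch.zeros(3, 8, 8), "labels": torch.tensor([1, 2]), "boxes": torch.zeros(2, 4)},
    ]
    out = collate_fn(batch)
    assert out["image"].shape == (2, 3, 8, 8)            # images stack into one tensor
    assert [b.shape[0] for b in out["boxes"]] == [1, 2]  # box counts vary per image

+
+
+class VHR10DataModule(pl.LightningDataModule):
+    """LightningDataModule implementation for the VHR10 dataset.
+
+    ..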
versionadded:: 0.4 + """ + + def __init__( + self, + root: str, + batch_size: int = 64, + num_workers: int = 0, + val_split_pct: float = 0.2, + test_split_pct: float = 0.2, + patch_size: Union[int, Tuple[int, int]] = 512, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for VHR10 based DataLoaders. + + Args: + root: The ``root`` argument to pass to the Dataset class + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + val_split_pct: What percentage of the dataset to use as a validation set + test_split_pct: What percentage of the dataset to use as a test set + patch_size: Patch size (height, width) for batched training + """ + super().__init__() + self.root = root + self.batch_size = batch_size + self.num_workers = num_workers + self.val_split_pct = val_split_pct + self.test_split_pct = test_split_pct + self.patch_size = cast(Tuple[int, int], _to_tuple(patch_size)) + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + + _, h, w = sample["image"].shape + sample["image"] = torchvision.transforms.functional.resize( + sample["image"], size=self.patch_size + ) + box_scale = (self.patch_size[1] / w, self.patch_size[0] / h) + sample["boxes"][:, 0] *= box_scale[0] + sample["boxes"][:, 1] *= box_scale[1] + sample["boxes"][:, 2] *= box_scale[0] + sample["boxes"][:, 3] *= box_scale[1] + sample["boxes"] = torch.round(sample["boxes"]) + + if "masks" in sample: + sample["masks"] = torchvision.transforms.functional.resize( + sample["masks"], size=self.patch_size + ) + + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + VHR10(self.root, download=True, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + + Args: + stage: stage to set up + """ + self.dataset = VHR10(self.root, transforms=self.preprocess) + self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( + self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct + ) + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training. + + Returns: + training data loader + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + collate_fn=collate_fn, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + collate_fn=collate_fn, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing. 
+ + Returns: + testing data loader + """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + collate_fn=collate_fn, + ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.VHR10.plot`.""" + return self.dataset.plot(*args, **kwargs) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 6b6aa6fbc6e..a0a9ea0db1a 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -74,7 +74,6 @@ from .millionaid import MillionAID from .naip import NAIP from .nasa_marine_debris import NASAMarineDebris -from .nwpu import VHR10 from .openbuildings import OpenBuildings from .oscd import OSCD from .patternnet import PatternNet @@ -104,6 +103,7 @@ unbind_samples, ) from .vaihingen import Vaihingen2D +from .vhr10 import VHR10 from .xview import XView2 from .zuericrop import ZueriCrop diff --git a/torchgeo/datasets/nwpu.py b/torchgeo/datasets/nwpu.py deleted file mode 100644 index 12bada1a674..00000000000 --- a/torchgeo/datasets/nwpu.py +++ /dev/null @@ -1,246 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -"""NWPU VHR-10 dataset.""" - -import os -from typing import Any, Callable, Dict, Optional - -import numpy as np -import torch -from PIL import Image -from torch import Tensor - -from .geo import NonGeoDataset -from .utils import check_integrity, download_and_extract_archive, download_url - - -class VHR10(NonGeoDataset): - """NWPU VHR-10 dataset. - - Northwestern Polytechnical University (NWPU) very-high-resolution ten-class (VHR-10) - remote sensing image dataset. - - Consists of 800 VHR optical remote sensing images, where 715 color images were - acquired from Google Earth with the spatial resolution ranging from 0.5 to 2 m, - and 85 pansharpened color infrared (CIR) images were acquired from Vaihingen data - with a spatial resolution of 0.08 m. - - The data set is divided into two sets: - - * Positive image set (650 images) which contains at least one target in an image - * Negative image set (150 images) does not contain any targets - - The positive image set consists of objects from ten classes: - - 1. Airplanes (757) - 2. Ships (302) - 3. Storage tanks (655) - 4. Baseball diamonds (390) - 5. Tennis courts (524) - 6. Basketball courts (159) - 7. Ground track fields (163) - 8. Harbors (224) - 9. Bridges (124) - 10. Vehicles (477) - - Includes object detection bounding boxes from original paper and instance - segmentation masks from follow-up publications. If you use this dataset in your - research, please cite the following papers: - - * https://doi.org/10.1016/j.isprsjprs.2014.10.002 - * https://doi.org/10.1109/IGARSS.2019.8898573 - * https://doi.org/10.3390/rs12060989 - - .. 
note:: - - This dataset requires the following additional libraries to be installed: - - * `pycocotools `_ to load the - ``annotations.json`` file for the "positive" image set - * `rarfile `_ to extract the dataset, - which is stored in a RAR file - """ - - image_meta = { - "url": "https://drive.google.com/file/d/1--foZ3dV5OCsqXQXT84UeKtrAqc5CkAE", - "filename": "NWPU VHR-10 dataset.rar", - "md5": "d30a7ff99d92123ebb0b3a14d9102081", - } - target_meta = { - "url": ( - "https://raw.githubusercontent.com/chaozhong2010/VHR-10_dataset_coco/" - "master/NWPU%20VHR-10_dataset_coco/annotations.json" - ), - "filename": "annotations.json", - "md5": "7c76ec50c17a61bb0514050d20f22c08", - } - - def __init__( - self, - root: str = "data", - split: str = "positive", - transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, - download: bool = False, - checksum: bool = False, - ) -> None: - """Initialize a new VHR-10 dataset instance. - - Args: - root: root directory where dataset can be found - split: one of "postive" or "negative" - transforms: a function/transform that takes input sample and its target as - entry and returns a transformed version - download: if True, download dataset and store it in the root directory - checksum: if True, check the MD5 of the downloaded files (may be slow) - - Raises: - AssertionError: if ``split`` argument is invalid - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match - """ - assert split in ["positive", "negative"] - - self.root = root - self.split = split - self.transforms = transforms - self.checksum = checksum - - if download: - self._download() - - if not self._check_integrity(): - raise RuntimeError( - "Dataset not found or corrupted. " - + "You can use download=True to download it" - ) - - if split == "positive": - # Must be installed to parse annotations file - try: - from pycocotools.coco import COCO # noqa: F401 - except ImportError: - raise ImportError( - "pycocotools is not installed and is required to use this dataset" - ) - - self.coco = COCO( - os.path.join( - self.root, "NWPU VHR-10 dataset", self.target_meta["filename"] - ) - ) - - def __getitem__(self, index: int) -> Dict[str, Any]: - """Return an index within the dataset. - - Args: - index: index to return - - Returns: - data and label at that index - """ - id_ = index % len(self) + 1 - sample = {"image": self._load_image(id_), "label": self._load_target(id_)} - - if self.transforms is not None: - sample = self.transforms(sample) - - return sample - - def __len__(self) -> int: - """Return the number of data points in the dataset. - - Returns: - length of the dataset - """ - if self.split == "positive": - return 650 - else: - return 150 - - def _load_image(self, id_: int) -> Tensor: - """Load a single image. - - Args: - id_: unique ID of the image - - Returns: - the image - """ - filename = os.path.join( - self.root, - "NWPU VHR-10 dataset", - self.split + " image set", - f"{id_:03d}.jpg", - ) - with Image.open(filename) as img: - array: "np.typing.NDArray[np.int_]" = np.array(img) - tensor = torch.from_numpy(array) - # Convert from HxWxC to CxHxW - tensor = tensor.permute((2, 0, 1)) - return tensor - - def _load_target(self, id_: int) -> Dict[str, Any]: - """Load the annotations for a single image. 
- - Args: - id_: unique ID of the image - - Returns: - the annotations - """ - # Images in the "negative" image set have no annotations - annot = [] - if self.split == "positive": - annot = self.coco.loadAnns(self.coco.getAnnIds(id_)) - - target = dict(image_id=id_, annotations=annot) - - return target - - def _check_integrity(self) -> bool: - """Check integrity of dataset. - - Returns: - True if dataset files are found and/or MD5s match, else False - """ - image: bool = check_integrity( - os.path.join(self.root, self.image_meta["filename"]), - self.image_meta["md5"] if self.checksum else None, - ) - - # Annotations only needed for "positive" image set - target = True - if self.split == "positive": - target = check_integrity( - os.path.join( - self.root, "NWPU VHR-10 dataset", self.target_meta["filename"] - ), - self.target_meta["md5"] if self.checksum else None, - ) - - return image and target - - def _download(self) -> None: - """Download the dataset and extract it.""" - if self._check_integrity(): - print("Files already downloaded and verified") - return - - # Download images - download_and_extract_archive( - self.image_meta["url"], - self.root, - filename=self.image_meta["filename"], - md5=self.image_meta["md5"] if self.checksum else None, - ) - - # Annotations only needed for "positive" image set - if self.split == "positive": - # Download annotations - download_url( - self.target_meta["url"], - os.path.join(self.root, "NWPU VHR-10 dataset"), - self.target_meta["filename"], - self.target_meta["md5"] if self.checksum else None, - ) diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py new file mode 100644 index 00000000000..ade678ddbf5 --- /dev/null +++ b/torchgeo/datasets/vhr10.py @@ -0,0 +1,540 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""NWPU VHR-10 dataset.""" + +import os +from typing import Any, Callable, Dict, List, Optional + +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib import patches +from PIL import Image +from skimage.measure import find_contours +from torch import Tensor + +from .geo import NonGeoDataset +from .utils import check_integrity, download_and_extract_archive, download_url + + +def convert_coco_poly_to_mask( + segmentations: List[int], height: int, width: int +) -> Tensor: + """Convert coco polygons to mask tensor. + + Args: + segmentations (List[int]): polygon coordinates + height (int): image height + width (int): image width + + Returns: + Tensor: Mask tensor + """ + from pycocotools import mask as coco_mask # noqa: F401 + + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = torch.as_tensor(mask, dtype=torch.uint8) + mask = mask.any(dim=2) + masks.append(mask) + if masks: + masks_tensor = torch.stack(masks, dim=0) + else: + masks_tensor = torch.zeros((0, height, width), dtype=torch.uint8) + return masks_tensor + + +class ConvertCocoAnnotations: + """Callable for converting the boxes, masks and labels into tensors. + + This is a modified version of ConvertCocoPolysToMask() from torchvision found in + https://github.com/pytorch/vision/blob/main/references/detection/coco_utils.py + """ + + def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Converts MS COCO fields (boxes, masks & labels) from list of ints to tensors. 
+ + Args: + sample (Dict[str, Any]): Sample + + Returns: + Dict[str, Any]: Processed sample + """ + image = sample["image"] + _, h, w = image.size() + target = sample["label"] + + image_id = target["image_id"] + image_id = torch.tensor([image_id]) + + anno = target["annotations"] + + anno = [obj for obj in anno if obj["iscrowd"] == 0] + + bboxes = [obj["bbox"] for obj in anno] + # guard against no boxes via resizing + boxes = torch.as_tensor(bboxes, dtype=torch.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2].clamp_(min=0, max=w) + boxes[:, 1::2].clamp_(min=0, max=h) + + categories = [obj["category_id"] for obj in anno] + classes = torch.tensor(categories, dtype=torch.int64) + + if "segmentation" in anno[0]: + segmentations = [obj["segmentation"] for obj in anno] + else: + segmentations = [] + masks = convert_coco_poly_to_mask(segmentations, h, w) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + classes = classes[keep] + + target = {"boxes": boxes, "labels": classes, "image_id": image_id} + if masks.nelement() > 0: + masks = masks[keep] + target["masks"] = masks + + # for conversion to coco api + area = torch.tensor([obj["area"] for obj in anno]) + iscrowd = torch.tensor([obj["iscrowd"] for obj in anno]) + target["area"] = area + target["iscrowd"] = iscrowd + return {"image": image, "label": target} + + +class VHR10(NonGeoDataset): + """NWPU VHR-10 dataset. + + Northwestern Polytechnical University (NWPU) very-high-resolution ten-class (VHR-10) + remote sensing image dataset. + + Consists of 800 VHR optical remote sensing images, where 715 color images were + acquired from Google Earth with the spatial resolution ranging from 0.5 to 2 m, + and 85 pansharpened color infrared (CIR) images were acquired from Vaihingen data + with a spatial resolution of 0.08 m. + + The data set is divided into two sets: + + * Positive image set (650 images) which contains at least one target in an image + * Negative image set (150 images) does not contain any targets + + The positive image set consists of objects from ten classes: + + 1. Airplanes (757) + 2. Ships (302) + 3. Storage tanks (655) + 4. Baseball diamonds (390) + 5. Tennis courts (524) + 6. Basketball courts (159) + 7. Ground track fields (163) + 8. Harbors (224) + 9. Bridges (124) + 10. Vehicles (477) + + Includes object detection bounding boxes from original paper and instance + segmentation masks from follow-up publications. If you use this dataset in your + research, please cite the following papers: + + * https://doi.org/10.1016/j.isprsjprs.2014.10.002 + * https://doi.org/10.1109/IGARSS.2019.8898573 + * https://doi.org/10.3390/rs12060989 + + .. 
note::
+
+       This dataset requires the following additional libraries to be installed:
+
+       * `pycocotools <https://pypi.org/project/pycocotools/>`_ to load the
+         ``annotations.json`` file for the "positive" image set
+       * `rarfile <https://pypi.org/project/rarfile/>`_ to extract the dataset,
+         which is stored in a RAR file
+    """
+
+    image_meta = {
+        "url": "https://drive.google.com/file/d/1--foZ3dV5OCsqXQXT84UeKtrAqc5CkAE",
+        "filename": "NWPU VHR-10 dataset.rar",
+        "md5": "d30a7ff99d92123ebb0b3a14d9102081",
+    }
+    target_meta = {
+        "url": (
+            "https://raw.githubusercontent.com/chaozhong2010/VHR-10_dataset_coco/"
+            "master/NWPU%20VHR-10_dataset_coco/annotations.json"
+        ),
+        "filename": "annotations.json",
+        "md5": "7c76ec50c17a61bb0514050d20f22c08",
+    }
+
+    categories = [
+        "background",
+        "airplane",
+        "ships",
+        "storage tank",
+        "baseball diamond",
+        "tennis court",
+        "basketball court",
+        "ground track field",
+        "harbor",
+        "bridge",
+        "vehicle",
+    ]
+
+    def __init__(
+        self,
+        root: str = "data",
+        split: str = "positive",
+        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+        download: bool = False,
+        checksum: bool = False,
+    ) -> None:
+        """Initialize a new VHR-10 dataset instance.
+
+        Args:
+            root: root directory where dataset can be found
+            split: one of "positive" or "negative"
+            transforms: a function/transform that takes input sample and its target as
+                entry and returns a transformed version
+            download: if True, download dataset and store it in the root directory
+            checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+        Raises:
+            AssertionError: if ``split`` argument is invalid
+            RuntimeError: if ``download=False`` and data is not found, or checksums
+                don't match
+        """
+        assert split in ["positive", "negative"]
+
+        self.root = root
+        self.split = split
+        self.transforms = transforms
+        self.checksum = checksum
+
+        if download:
+            self._download()
+
+        if not self._check_integrity():
+            raise RuntimeError(
+                "Dataset not found or corrupted. "
+                + "You can use download=True to download it"
+            )
+
+        if split == "positive":
+            # Must be installed to parse annotations file
+            try:
+                from pycocotools.coco import COCO  # noqa: F401
+            except ImportError:
+                raise ImportError(
+                    "pycocotools is not installed and is required to use this dataset"
+                )
+
+            self.coco = COCO(
+                os.path.join(
+                    self.root, "NWPU VHR-10 dataset", self.target_meta["filename"]
+                )
+            )
+
+            self.coco_convert = ConvertCocoAnnotations()
+            self.ids = list(sorted(self.coco.imgs.keys()))
+
+    def __getitem__(self, index: int) -> Dict[str, Any]:
+        """Return an index within the dataset.
+
+        Args:
+            index: index to return
+
+        Returns:
+            data and label at that index
+        """
+        id_ = index % len(self) + 1
+
+        sample: Dict[str, Any] = {
+            "image": self._load_image(id_),
+            "label": self._load_target(id_),
+        }
+
+        if sample["label"]["annotations"]:
+            sample = self.coco_convert(sample)
+            sample["labels"] = sample["label"]["labels"]
+            sample["boxes"] = sample["label"]["boxes"]
+            if "masks" in sample["label"]:
+                sample["masks"] = sample["label"]["masks"]
+            del sample["label"]
+
+        if self.transforms is not None:
+            sample = self.transforms(sample)
+
+        return sample
+
+    def __len__(self) -> int:
+        """Return the number of data points in the dataset.
+
+        Returns:
+            length of the dataset
+        """
+        if self.split == "positive":
+            return len(self.ids)
+        else:
+            return 150
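A hedged usage sketch of the refactored dataset may help orient reviewers; it assumes the archive and annotations.json already sit under root (download=True would additionally need rarfile and Google Drive access):

    from torchgeo.datasets import VHR10

    ds = VHR10(root="data", split="positive")
    sample = ds[0]
    # sample["image"] is a (3, H, W) tensor; positive samples also carry
    # "boxes" (N, 4, xyxy), "labels" (N,), and "masks" when the annotation
    # includes segmentation polygons
    print(len(ds), sample["image"].shape)

+
+    def _load_image(self, id_: int) -> Tensor:
+        """Load a single image.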
+
+        Args:
+            id_: unique ID of the image
+
+        Returns:
+            the image
+        """
+        filename = os.path.join(
+            self.root,
+            "NWPU VHR-10 dataset",
+            self.split + " image set",
+            f"{id_:03d}.jpg",
+        )
+        with Image.open(filename) as img:
+            array: "np.typing.NDArray[np.int_]" = np.array(img)
+            tensor = torch.from_numpy(array)
+            # Convert from HxWxC to CxHxW
+            tensor = tensor.permute((2, 0, 1))
+            return tensor
+
+    def _load_target(self, id_: int) -> Dict[str, Any]:
+        """Load the annotations for a single image.
+
+        Args:
+            id_: unique ID of the image
+
+        Returns:
+            the annotations
+        """
+        # Images in the "negative" image set have no annotations
+        annot = []
+        if self.split == "positive":
+            annot = self.coco.loadAnns(self.coco.getAnnIds(id_ - 1))
+
+        target = dict(image_id=id_, annotations=annot)
+
+        return target
+
+    def _check_integrity(self) -> bool:
+        """Check integrity of dataset.
+
+        Returns:
+            True if dataset files are found and/or MD5s match, else False
+        """
+        image: bool = check_integrity(
+            os.path.join(self.root, self.image_meta["filename"]),
+            self.image_meta["md5"] if self.checksum else None,
+        )
+
+        # Annotations only needed for "positive" image set
+        target = True
+        if self.split == "positive":
+            target = check_integrity(
+                os.path.join(
+                    self.root, "NWPU VHR-10 dataset", self.target_meta["filename"]
+                ),
+                self.target_meta["md5"] if self.checksum else None,
+            )
+
+        return image and target
+
+    def _download(self) -> None:
+        """Download the dataset and extract it."""
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        # Download images
+        download_and_extract_archive(
+            self.image_meta["url"],
+            self.root,
+            filename=self.image_meta["filename"],
+            md5=self.image_meta["md5"] if self.checksum else None,
+        )
+
+        # Annotations only needed for "positive" image set
+        if self.split == "positive":
+            # Download annotations
+            download_url(
+                self.target_meta["url"],
+                os.path.join(self.root, "NWPU VHR-10 dataset"),
+                self.target_meta["filename"],
+                self.target_meta["md5"] if self.checksum else None,
+            )
+
+    def plot(
+        self,
+        sample: Dict[str, Tensor],
+        show_titles: bool = True,
+        suptitle: Optional[str] = None,
+        show_feats: Optional[str] = "both",
+        box_alpha: float = 0.7,
+        mask_alpha: float = 0.7,
+    ) -> plt.Figure:
+        """Plot a sample from the dataset.
+
+        Args:
+            sample: a sample returned by :meth:`__getitem__`
+            show_titles: flag indicating whether to show titles above each panel
+            suptitle: optional string to use as a suptitle
+            show_feats: optional string to pick features to be shown: boxes, masks, both
+            box_alpha: alpha value of box
+            mask_alpha: alpha value of mask
+
+        Returns:
+            a matplotlib Figure with the rendered sample
+
+        ..
versionadded:: 0.4 + """ + assert show_feats in {"boxes", "masks", "both"} + + if self.split == "negative": + plt.imshow(sample["image"].permute(1, 2, 0)) + axs = plt.gca() + axs.axis("off") + + if suptitle is not None: + plt.suptitle(suptitle) + return plt.gcf() + + image = sample["image"].permute(1, 2, 0).numpy() + boxes = sample["boxes"].cpu().numpy() + labels = sample["labels"].cpu().numpy() + if "masks" in sample: + masks = [mask.squeeze().cpu().numpy() for mask in sample["masks"]] + + N_GT = len(boxes) + + ncols = 1 + show_predictions = "prediction_labels" in sample + + if show_predictions: + show_pred_boxes = False + show_pred_masks = False + prediction_labels = sample["prediction_labels"].numpy() + prediction_scores = sample["prediction_scores"].numpy() + if "prediction_boxes" in sample: + prediction_boxes = sample["prediction_boxes"].numpy() + show_pred_boxes = True + if "prediction_masks" in sample: + prediction_masks = sample["prediction_masks"].numpy() + show_pred_masks = True + + N_PRED = len(prediction_labels) + ncols += 1 + + # Display image + fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) + if not isinstance(axs, np.ndarray): + axs = [axs] + axs[0].imshow(image) + axs[0].axis("off") + + cm = plt.get_cmap("gist_rainbow") + for i in range(N_GT): + class_num = labels[i] + color = cm(class_num / len(self.categories)) + + # Add bounding boxes + x1, y1, x2, y2 = boxes[i] + if show_feats in {"boxes", "both"}: + p = patches.Rectangle( + (x1, y1), + x2 - x1, + y2 - y1, + linewidth=2, + alpha=box_alpha, + linestyle="dashed", + edgecolor=color, + facecolor="none", + ) + axs[0].add_patch(p) + + # Add labels + label = self.categories[class_num] + caption = label + axs[0].text( + x1, y1 - 8, caption, color="white", size=11, backgroundcolor="none" + ) + + # Add masks + if show_feats in {"masks", "both"} and "masks" in sample: + mask = masks[i] + contours = find_contours(mask, 0.5) + for verts in contours: + verts = np.fliplr(verts) + p = patches.Polygon( + verts, facecolor=color, alpha=mask_alpha, edgecolor="white" + ) + axs[0].add_patch(p) + + if show_titles: + axs[0].set_title("Ground Truth") + + if show_predictions: + axs[1].imshow(image) + axs[1].axis("off") + for i in range(N_PRED): + score = prediction_scores[i] + if score < 0.5: + continue + + # TODO: Check scores + + class_num = prediction_labels[i] + color = cm(class_num / len(self.categories)) + + if show_pred_boxes: + # Add bounding boxes + x1, y1, x2, y2 = prediction_boxes[i] + p = patches.Rectangle( + (x1, y1), + x2 - x1, + y2 - y1, + linewidth=2, + alpha=box_alpha, + linestyle="dashed", + edgecolor=color, + facecolor="none", + ) + axs[1].add_patch(p) + + # Add labels + label = self.categories[class_num] + caption = f"{label} {score:.3f}" + axs[1].text( + x1, + y1 - 8, + caption, + color="white", + size=11, + backgroundcolor="none", + ) + + # TODO: Labels are dependent on boxes being shown + # Add masks + if show_pred_masks: + mask = prediction_masks[i] + contours = find_contours(mask, 0.5) + for verts in contours: + verts = np.fliplr(verts) + p = patches.Polygon( + verts, facecolor=color, alpha=mask_alpha, edgecolor="white" + ) + axs[1].add_patch(p) + + if show_titles: + axs[1].set_title("Prediction") + + plt.tight_layout() + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig diff --git a/train.py b/train.py index 3f3fd7fb77f..8b4f1d0e153 100755 --- a/train.py +++ b/train.py @@ -29,6 +29,7 @@ So2SatDataModule, TropicalCycloneDataModule, UCMercedDataModule, + VHR10DataModule, ) from 
torchgeo.trainers import ( BYOLTask, @@ -58,6 +59,7 @@ "sen12ms": (SemanticSegmentationTask, SEN12MSDataModule), "so2sat": (ClassificationTask, So2SatDataModule), "ucmerced": (ClassificationTask, UCMercedDataModule), + "vhr10": (ObjectDetectionTask, VHR10DataModule), }
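As with the other entries in this registry, the new task can be driven with python train.py config_file=conf/vhr10.yaml. A minimal manual wiring sketch, mirroring the module and datamodule values from conf/vhr10.yaml (worker count and trainer flags illustrative; pytorch-lightning 1.x API):

    from pytorch_lightning import Trainer

    from torchgeo.datamodules import VHR10DataModule
    from torchgeo.trainers import ObjectDetectionTask

    datamodule = VHR10DataModule(
        root="data/vhr10", batch_size=2, patch_size=512, num_workers=4, val_split_pct=0.2
    )
    task = ObjectDetectionTask(
        detection_model="faster-rcnn",
        backbone="resnet50",
        pretrained=True,
        num_classes=11,
        learning_rate=1.3e-5,
        learning_rate_schedule_patience=6,
        verbose=False,
    )
    Trainer(gpus=1, min_epochs=5, max_epochs=100).fit(model=task, datamodule=datamodule)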