diff --git a/conf/vhr10.yaml b/conf/vhr10.yaml
new file mode 100644
index 00000000000..71a7cbaa43b
--- /dev/null
+++ b/conf/vhr10.yaml
@@ -0,0 +1,28 @@
+program:
+ seed: 0
+ overwrite: True
+
+trainer:
+ gpus: 1
+ min_epochs: 5
+ max_epochs: 100
+ auto_lr_find: False
+ benchmark: True
+
+experiment:
+ task: "vhr10"
+ name: "vhr10_test"
+ module:
+ detection_model: "faster-rcnn"
+ backbone: "resnet50"
+ pretrained: True
+ num_classes: 11
+ learning_rate: 1.3e-5
+ learning_rate_schedule_patience: 6
+ verbose: false
+ datamodule:
+ root: "data/vhr10"
+ batch_size: 2
+ patch_size: 512
+ num_workers: 56
+ val_split_pct: 0.2
diff --git a/environment.yml b/environment.yml
index dafd0778a56..c82849aaa35 100644
--- a/environment.yml
+++ b/environment.yml
@@ -16,6 +16,7 @@ dependencies:
- pytorch>=1.9
- rarfile>=3
- rasterio>=1.0.20
+ - scikit-image>=0.15.0
- shapely>=1.3
- torchvision>=0.10
- pip:
diff --git a/requirements/min.old b/requirements/min.old
index 54d85ac290a..cad884bfe9a 100644
--- a/requirements/min.old
+++ b/requirements/min.old
@@ -14,6 +14,7 @@ pyproj==2.2.0
pytorch-lightning==1.5.1
rasterio==1.0.20
rtree==1.0.0
+scikit-image==0.15.0
scikit-learn==0.21.0
segmentation-models-pytorch==0.2.0
shapely==1.3.0
@@ -28,10 +29,11 @@ laspy==2.0.0
open3d==0.11.2
opencv-python==3.4.2.17
pandas==0.23.2
-pycocotools==2.0.0
+pycocotools==2.0.1
radiant-mlhub==0.2.1
rarfile==3.0
scipy==1.2.0
+scikit-image==0.15.0
zipfile-deflate64==0.2.0
# docs
diff --git a/requirements/required.old b/requirements/required.old
index 1e120e0948c..1bb2b38d75d 100644
--- a/requirements/required.old
+++ b/requirements/required.old
@@ -17,6 +17,7 @@ pytorch-lightning==1.6.4
rasterio==1.3.0;python_version>='3.8'
rasterio==1.2.10;python_version=='3.7'
rtree==1.0.0
+scikit-image>=0.15.0;
scikit-learn==1.1.1;python_version>='3.8'
scikit-learn==1.0.2;python_version=='3.7'
segmentation-models-pytorch==0.2.1
diff --git a/requirements/required.txt b/requirements/required.txt
index 0400203b78a..604eb667c6c 100644
--- a/requirements/required.txt
+++ b/requirements/required.txt
@@ -14,6 +14,7 @@ pyproj==3.4.0;python_version>='3.8'
pytorch-lightning==1.7.7
rasterio==1.3.2;python_version>='3.8'
rtree==1.0.1
+scikit-image>=0.15.0;
scikit-learn==1.1.2;python_version>='3.8'
segmentation-models-pytorch==0.3.0
shapely==1.8.5.post1
diff --git a/setup.cfg b/setup.cfg
index 58559d5f52e..50a237da810 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -51,6 +51,8 @@ install_requires =
rtree>=1,<2
# scikit-learn 0.21+ required to fix murmurhash3_32 import bug
scikit-learn>=0.21,<2
+ # scikit-image required for find_contours
+ scikit-image>=0.15.0,<0.20
# segmentation-models-pytorch 0.2+ required for smp.losses module
segmentation-models-pytorch>=0.2,<0.4
# shapely 1.3+ required for Python 3 support
diff --git a/tests/conf/vhr10.yaml b/tests/conf/vhr10.yaml
new file mode 100644
index 00000000000..ad432e047b9
--- /dev/null
+++ b/tests/conf/vhr10.yaml
@@ -0,0 +1,16 @@
+
+experiment:
+ task: "vhr10"
+ module:
+ detection_model: "faster-rcnn"
+ backbone: "resnet18"
+ num_classes: 11
+ learning_rate: 1e-4
+ learning_rate_schedule_patience: 6
+ verbose: false
+ datamodule:
+ root: "tests/data/vhr10"
+ seed: 0
+ batch_size: 1
+ num_workers: 0
+ patch_size: 4
diff --git a/tests/data/vhr10/NWPU VHR-10 dataset.rar b/tests/data/vhr10/NWPU VHR-10 dataset.rar
index 5fc4953cef1..3dc735dd4b6 100644
Binary files a/tests/data/vhr10/NWPU VHR-10 dataset.rar and b/tests/data/vhr10/NWPU VHR-10 dataset.rar differ
diff --git a/tests/data/vhr10/annotations.json b/tests/data/vhr10/annotations.json
index 406ddc21a9b..60de0cb14f3 100644
--- a/tests/data/vhr10/annotations.json
+++ b/tests/data/vhr10/annotations.json
@@ -1,134 +1 @@
-{
- "info": {
- "description": null,
- "url": null,
- "version": null,
- "year": 2021,
- "contributor": null,
- "date_created": "2021-01-01 00:00:00"
- },
- "licenses": [
- {
- "url": null,
- "id": 0,
- "name": null
- }
- ],
- "images": [
- {
- "license": 0,
- "url": null,
- "file_name": "001.jpg",
- "height": 1,
- "width": 1,
- "date_captured": null,
- "id": 0
- },
- {
- "license": 0,
- "url": null,
- "file_name": "002.jpg",
- "height": 1,
- "width": 1,
- "date_captured": null,
- "id": 1
- }
- ],
- "type": "instances",
- "annotations": [
- {
- "id": 0,
- "image_id": 0,
- "category_id": 1,
- "segmentation": [
- [
- 1,
- 2,
- 3,
- 4
- ]
- ],
- "area": 1.0,
- "bbox": [
- 1,
- 2,
- 3,
- 4
- ],
- "iscrowd": 0
- },
- {
- "id": 1,
- "image_id": 1,
- "category_id": 1,
- "segmentation": [
- [
- 1,
- 2,
- 3,
- 4
- ]
- ],
- "area": 1.0,
- "bbox": [
- 1,
- 2,
- 3,
- 4
- ],
- "iscrowd": 0
- }
- ],
- "categories": [
- {
- "supercategory": null,
- "id": 1,
- "name": "airplane"
- },
- {
- "supercategory": null,
- "id": 2,
- "name": "ship"
- },
- {
- "supercategory": null,
- "id": 3,
- "name": "storage_tank"
- },
- {
- "supercategory": null,
- "id": 4,
- "name": "baseball_diamond"
- },
- {
- "supercategory": null,
- "id": 5,
- "name": "tennis_court"
- },
- {
- "supercategory": null,
- "id": 6,
- "name": "basketball_court"
- },
- {
- "supercategory": null,
- "id": 7,
- "name": "ground_track_field"
- },
- {
- "supercategory": null,
- "id": 8,
- "name": "harbor"
- },
- {
- "supercategory": null,
- "id": 9,
- "name": "bridge"
- },
- {
- "supercategory": null,
- "id": 10,
- "name": "vehicle"
- }
- ]
-}
+{"images": [{"file_name": "001.jpg", "height": 8, "width": 8, "id": 0}, {"file_name": "002.jpg", "height": 8, "width": 8, "id": 1}, {"file_name": "003.jpg", "height": 8, "width": 8, "id": 2}, {"file_name": "004.jpg", "height": 8, "width": 8, "id": 3}, {"file_name": "005.jpg", "height": 8, "width": 8, "id": 4}], "annotations": [{"id": 0, "image_id": 0, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "iscrowd": 0}, {"id": 1, "image_id": 1, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 2, "image_id": 2, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 3, "image_id": 3, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 4, "image_id": 4, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}]}
\ No newline at end of file
diff --git a/tests/data/vhr10/data.py b/tests/data/vhr10/data.py
new file mode 100644
index 00000000000..3be22226954
--- /dev/null
+++ b/tests/data/vhr10/data.py
@@ -0,0 +1,104 @@
+import json
+import os
+import shutil
+import subprocess
+import warnings
+from copy import deepcopy
+
+import numpy as np
+import rasterio as rio
+from rasterio.errors import NotGeoreferencedWarning
+from torchvision.datasets.utils import calculate_md5
+
+ANNOTATION_FILE = {"images": [], "annotations": []}
+
+
+def write_data(path: str, img: np.ndarray) -> None:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=NotGeoreferencedWarning)
+ with rio.open(
+ path,
+ "w",
+ driver="JP2OpenJPEG",
+ height=img.shape[0],
+ width=img.shape[1],
+ count=3,
+ dtype=img.dtype,
+ ) as dst:
+ for i in range(1, dst.count + 1):
+ dst.write(img, i)
+
+
+def generate_test_data(root: str, n_imgs: int = 3) -> str:
+ folder_path = os.path.join(root, "NWPU VHR-10 dataset")
+ pos_img_dir = os.path.join(folder_path, "positive image set")
+ neg_img_dir = os.path.join(folder_path, "negative image set")
+ ann_file = os.path.join(folder_path, "annotations.json")
+ ann_file2 = os.path.join(root, "annotations.json")
+
+ if not os.path.exists(pos_img_dir):
+ os.makedirs(pos_img_dir)
+ if not os.path.exists(neg_img_dir):
+ os.makedirs(neg_img_dir)
+
+ for img_id in range(1, n_imgs + 1):
+ pos_img_name = os.path.join(pos_img_dir, f"00{img_id}.jpg")
+ neg_img_name = os.path.join(neg_img_dir, f"00{img_id}.jpg")
+
+ img = np.random.randint(255, size=(8, 8), dtype=np.dtype("uint8"))
+ write_data(pos_img_name, img)
+ write_data(neg_img_name, img)
+
+ img_name = os.path.basename(pos_img_name)
+
+ ANNOTATION_FILE["images"].append(
+ {"file_name": img_name, "height": 8, "width": 8, "id": img_id - 1}
+ )
+
+ ann = 0
+ import pdb
+
+ pdb.set_trace()
+ for i, img in enumerate(ANNOTATION_FILE["images"]):
+ annot = {
+ "id": ann,
+ "image_id": img["id"],
+ "category_id": 1,
+ "area": 4.0,
+ "bbox": [4, 4, 2, 2],
+ "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]],
+ "iscrowd": 0,
+ }
+ if i != 0:
+ ANNOTATION_FILE["annotations"].append(annot)
+ else:
+ noseg_annot = deepcopy(annot)
+ del noseg_annot["segmentation"]
+ ANNOTATION_FILE["annotations"].append(noseg_annot)
+ ann += 1
+ import pdb
+
+ pdb.set_trace()
+ with open(ann_file, "w") as j:
+ json.dump(ANNOTATION_FILE, j)
+
+ with open(ann_file2, "w") as j:
+ json.dump(ANNOTATION_FILE, j)
+
+ # Create rar file
+ subprocess.run(
+ ["rar", "a", "NWPU VHR-10 dataset.rar", "-m5", "NWPU VHR-10 dataset"],
+ capture_output=True,
+ check=True,
+ )
+
+ annotations_md5 = calculate_md5(ann_file)
+ archive_md5 = calculate_md5("NWPU VHR-10 dataset.rar")
+ shutil.rmtree(folder_path)
+
+ return f"archive md5: {archive_md5}, annotation md5: {annotations_md5}"
+
+
+if __name__ == "__main__":
+ md5 = generate_test_data(os.getcwd(), 5)
+ print(md5)
diff --git a/tests/datamodules/test_vhr10.py b/tests/datamodules/test_vhr10.py
new file mode 100644
index 00000000000..3fbb9a6bd29
--- /dev/null
+++ b/tests/datamodules/test_vhr10.py
@@ -0,0 +1,37 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+
+import pytest
+
+from torchgeo.datamodules import VHR10DataModule
+
+
+class TestVHR10DataModule:
+ @pytest.fixture(scope="class")
+ def datamodule(self) -> VHR10DataModule:
+ root = os.path.join("tests", "data", "vhr10")
+ batch_size = 1
+ num_workers = 0
+ val_split_pct = 0.4
+ test_split_pct = 0.2
+ dm = VHR10DataModule(
+ root=root,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ val_split_pct=val_split_pct,
+ test_split_pct=test_split_pct,
+ )
+ dm.prepare_data()
+ dm.setup()
+ return dm
+
+ def test_train_dataloader(self, datamodule: VHR10DataModule) -> None:
+ next(iter(datamodule.train_dataloader()))
+
+ def test_val_dataloader(self, datamodule: VHR10DataModule) -> None:
+ next(iter(datamodule.val_dataloader()))
+
+ def test_test_dataloader(self, datamodule: VHR10DataModule) -> None:
+ next(iter(datamodule.test_dataloader()))
diff --git a/tests/datasets/test_nwpu.py b/tests/datasets/test_vhr10.py
similarity index 68%
rename from tests/datasets/test_nwpu.py
rename to tests/datasets/test_vhr10.py
index 582b6730908..d2400ee706c 100644
--- a/tests/datasets/test_nwpu.py
+++ b/tests/datasets/test_vhr10.py
@@ -7,6 +7,7 @@
from pathlib import Path
from typing import Any
+import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
@@ -30,15 +31,15 @@ def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> VHR10:
pytest.importorskip("rarfile", minversion="3")
- monkeypatch.setattr(torchgeo.datasets.nwpu, "download_url", download_url)
+ monkeypatch.setattr(torchgeo.datasets.vhr10, "download_url", download_url)
monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url)
url = os.path.join("tests", "data", "vhr10", "NWPU VHR-10 dataset.rar")
monkeypatch.setitem(VHR10.image_meta, "url", url)
- md5 = "e5c38351bd948479fe35a71136aedbc4"
+ md5 = "e68727b2c91ac849def1985924d9586f"
monkeypatch.setitem(VHR10.image_meta, "md5", md5)
url = os.path.join("tests", "data", "vhr10", "annotations.json")
monkeypatch.setitem(VHR10.target_meta, "url", url)
- md5 = "16fc6aa597a19179dad84151cc221873"
+ md5 = "833899cce369168e0d4ee420dac326dc"
monkeypatch.setitem(VHR10.target_meta, "md5", md5)
root = str(tmp_path)
split = request.param
@@ -57,14 +58,19 @@ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
monkeypatch.setattr(builtins, "__import__", mocked_import)
def test_getitem(self, dataset: VHR10) -> None:
- x = dataset[0]
- assert isinstance(x, dict)
- assert isinstance(x["image"], torch.Tensor)
- assert isinstance(x["label"], dict)
+ for i in range(2):
+ x = dataset[i]
+ assert isinstance(x, dict)
+ assert isinstance(x["image"], torch.Tensor)
+ if dataset.split == "positive":
+ assert isinstance(x["labels"], torch.Tensor)
+ assert isinstance(x["boxes"], torch.Tensor)
+ if "masks" in x:
+ assert isinstance(x["masks"], torch.Tensor)
def test_len(self, dataset: VHR10) -> None:
if dataset.split == "positive":
- assert len(dataset) == 650
+ assert len(dataset) == len(dataset.ids)
elif dataset.split == "negative":
assert len(dataset) == 150
@@ -72,7 +78,7 @@ def test_add(self, dataset: VHR10) -> None:
ds = dataset + dataset
assert isinstance(ds, ConcatDataset)
if dataset.split == "positive":
- assert len(ds) == 1300
+ assert len(ds) == 10
elif dataset.split == "negative":
assert len(ds) == 300
@@ -96,3 +102,21 @@ def test_mock_missing_module(
match="pycocotools is not installed and is required to use this datase",
):
VHR10(dataset.root, dataset.split)
+
+ def test_plot(self, dataset: VHR10) -> None:
+ x = dataset[1].copy()
+ dataset.plot(x, suptitle="Test")
+ plt.close()
+ dataset.plot(x, show_titles=False)
+ plt.close()
+ if dataset.split == "positive":
+ scores = [0.7, 0.3, 0.7]
+ for i in range(3):
+ x = dataset[i]
+ x["prediction_labels"] = x["labels"]
+ x["prediction_boxes"] = x["boxes"]
+ x["prediction_scores"] = torch.Tensor([scores[i]])
+ if "masks" in x:
+ x["prediction_masks"] = x["masks"]
+ dataset.plot(x, show_feats="masks")
+ plt.close()
diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py
index 0f2d8a65cc6..f2d9e78eb22 100644
--- a/tests/trainers/test_detection.py
+++ b/tests/trainers/test_detection.py
@@ -9,13 +9,17 @@
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer
-from torchgeo.datamodules import NASAMarineDebrisDataModule
+from torchgeo.datamodules import NASAMarineDebrisDataModule, VHR10DataModule
from torchgeo.trainers import ObjectDetectionTask
class TestObjectDetectionTask:
@pytest.mark.parametrize(
- "name,classname", [("nasa_marine_debris", NASAMarineDebrisDataModule)]
+ "name,classname",
+ [
+ ("nasa_marine_debris", NASAMarineDebrisDataModule),
+ ("vhr10", VHR10DataModule),
+ ],
)
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", f"{name}.yaml"))
diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py
index 2230cf00eaf..72e44adf836 100644
--- a/torchgeo/datamodules/__init__.py
+++ b/torchgeo/datamodules/__init__.py
@@ -24,6 +24,7 @@
from .ucmerced import UCMercedDataModule
from .usavars import USAVarsDataModule
from .vaihingen import Vaihingen2DDataModule
+from .vhr10 import VHR10DataModule
from .xview import XView2DataModule
__all__ = (
@@ -50,6 +51,7 @@
"UCMercedDataModule",
"USAVarsDataModule",
"Vaihingen2DDataModule",
+ "VHR10DataModule",
"XView2DataModule",
)
diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py
new file mode 100644
index 00000000000..2e0e3eb35c4
--- /dev/null
+++ b/torchgeo/datamodules/vhr10.py
@@ -0,0 +1,170 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""NWPU VHR-10 datamodule."""
+
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
+
+import matplotlib.pyplot as plt
+import pytorch_lightning as pl
+import torch
+import torchvision
+from torch import Tensor
+from torch.utils.data import DataLoader
+
+from ..datasets import VHR10
+from ..samplers.utils import _to_tuple
+from .utils import dataset_split
+
+# https://github.com/pytorch/pytorch/issues/60979
+# https://github.com/pytorch/pytorch/pull/61045
+DataLoader.__module__ = "torch.utils.data"
+
+
+def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
+ """Custom object detection collate fn to handle variable boxes.
+
+ Args:
+ batch: list of sample dicts return by dataset
+
+ Returns:
+ batch dict output
+ """
+ output: Dict[str, Any] = {}
+ output["image"] = torch.stack([sample["image"] for sample in batch])
+ output["labels"] = [sample["labels"] for sample in batch]
+ output["boxes"] = [sample["boxes"] for sample in batch]
+ if "masks" in batch[0]:
+ output["masks"] = [sample["masks"] for sample in batch]
+ return output
+
+
+class VHR10DataModule(pl.LightningDataModule):
+ """LightningDataModule implementation for the VHR10 dataset.
+
+ .. versionadded:: 0.4
+ """
+
+ def __init__(
+ self,
+ root: str,
+ batch_size: int = 64,
+ num_workers: int = 0,
+ val_split_pct: float = 0.2,
+ test_split_pct: float = 0.2,
+ patch_size: Union[int, Tuple[int, int]] = 512,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize a LightningDataModule for VHR10 based DataLoaders.
+
+ Args:
+ root: The ``root`` argument to pass to the Dataset class
+ batch_size: The batch size to use in all created DataLoaders
+ num_workers: The number of workers to use in all created DataLoaders
+ val_split_pct: What percentage of the dataset to use as a validation set
+ test_split_pct: What percentage of the dataset to use as a test set
+ patch_size: Patch size (height, width) for batched training
+ """
+ super().__init__()
+ self.root = root
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.val_split_pct = val_split_pct
+ self.test_split_pct = test_split_pct
+ self.patch_size = cast(Tuple[int, int], _to_tuple(patch_size))
+
+ def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+ """Transform a single sample from the Dataset.
+
+ Args:
+ sample: input image dictionary
+
+ Returns:
+ preprocessed sample
+ """
+ sample["image"] = sample["image"].float()
+ sample["image"] /= 255.0
+
+ _, h, w = sample["image"].shape
+ sample["image"] = torchvision.transforms.functional.resize(
+ sample["image"], size=self.patch_size
+ )
+ box_scale = (self.patch_size[1] / w, self.patch_size[0] / h)
+ sample["boxes"][:, 0] *= box_scale[0]
+ sample["boxes"][:, 1] *= box_scale[1]
+ sample["boxes"][:, 2] *= box_scale[0]
+ sample["boxes"][:, 3] *= box_scale[1]
+ sample["boxes"] = torch.round(sample["boxes"])
+
+ if "masks" in sample:
+ sample["masks"] = torchvision.transforms.functional.resize(
+ sample["masks"], size=self.patch_size
+ )
+
+ return sample
+
+ def prepare_data(self) -> None:
+ """Make sure that the dataset is downloaded.
+
+ This method is only called once per run.
+ """
+ VHR10(self.root, download=True, checksum=False)
+
+ def setup(self, stage: Optional[str] = None) -> None:
+ """Initialize the main ``Dataset`` objects.
+
+ This method is called once per GPU per run.
+
+ Args:
+ stage: stage to set up
+ """
+ self.dataset = VHR10(self.root, transforms=self.preprocess)
+ self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
+ self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
+ )
+
+ def train_dataloader(self) -> DataLoader[Any]:
+ """Return a DataLoader for training.
+
+ Returns:
+ training data loader
+ """
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ shuffle=True,
+ collate_fn=collate_fn,
+ )
+
+ def val_dataloader(self) -> DataLoader[Any]:
+ """Return a DataLoader for validation.
+
+ Returns:
+ validation data loader
+ """
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ shuffle=False,
+ collate_fn=collate_fn,
+ )
+
+ def test_dataloader(self) -> DataLoader[Any]:
+ """Return a DataLoader for testing.
+
+ Returns:
+ testing data loader
+ """
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ shuffle=False,
+ collate_fn=collate_fn,
+ )
+
+ def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
+ """Run :meth:`torchgeo.datasets.VHR10.plot`."""
+ return self.dataset.plot(*args, **kwargs)
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 6b6aa6fbc6e..a0a9ea0db1a 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -74,7 +74,6 @@
from .millionaid import MillionAID
from .naip import NAIP
from .nasa_marine_debris import NASAMarineDebris
-from .nwpu import VHR10
from .openbuildings import OpenBuildings
from .oscd import OSCD
from .patternnet import PatternNet
@@ -104,6 +103,7 @@
unbind_samples,
)
from .vaihingen import Vaihingen2D
+from .vhr10 import VHR10
from .xview import XView2
from .zuericrop import ZueriCrop
diff --git a/torchgeo/datasets/nwpu.py b/torchgeo/datasets/nwpu.py
deleted file mode 100644
index 12bada1a674..00000000000
--- a/torchgeo/datasets/nwpu.py
+++ /dev/null
@@ -1,246 +0,0 @@
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-
-"""NWPU VHR-10 dataset."""
-
-import os
-from typing import Any, Callable, Dict, Optional
-
-import numpy as np
-import torch
-from PIL import Image
-from torch import Tensor
-
-from .geo import NonGeoDataset
-from .utils import check_integrity, download_and_extract_archive, download_url
-
-
-class VHR10(NonGeoDataset):
- """NWPU VHR-10 dataset.
-
- Northwestern Polytechnical University (NWPU) very-high-resolution ten-class (VHR-10)
- remote sensing image dataset.
-
- Consists of 800 VHR optical remote sensing images, where 715 color images were
- acquired from Google Earth with the spatial resolution ranging from 0.5 to 2 m,
- and 85 pansharpened color infrared (CIR) images were acquired from Vaihingen data
- with a spatial resolution of 0.08 m.
-
- The data set is divided into two sets:
-
- * Positive image set (650 images) which contains at least one target in an image
- * Negative image set (150 images) does not contain any targets
-
- The positive image set consists of objects from ten classes:
-
- 1. Airplanes (757)
- 2. Ships (302)
- 3. Storage tanks (655)
- 4. Baseball diamonds (390)
- 5. Tennis courts (524)
- 6. Basketball courts (159)
- 7. Ground track fields (163)
- 8. Harbors (224)
- 9. Bridges (124)
- 10. Vehicles (477)
-
- Includes object detection bounding boxes from original paper and instance
- segmentation masks from follow-up publications. If you use this dataset in your
- research, please cite the following papers:
-
- * https://doi.org/10.1016/j.isprsjprs.2014.10.002
- * https://doi.org/10.1109/IGARSS.2019.8898573
- * https://doi.org/10.3390/rs12060989
-
- .. note::
-
- This dataset requires the following additional libraries to be installed:
-
- * `pycocotools `_ to load the
- ``annotations.json`` file for the "positive" image set
- * `rarfile `_ to extract the dataset,
- which is stored in a RAR file
- """
-
- image_meta = {
- "url": "https://drive.google.com/file/d/1--foZ3dV5OCsqXQXT84UeKtrAqc5CkAE",
- "filename": "NWPU VHR-10 dataset.rar",
- "md5": "d30a7ff99d92123ebb0b3a14d9102081",
- }
- target_meta = {
- "url": (
- "https://raw.githubusercontent.com/chaozhong2010/VHR-10_dataset_coco/"
- "master/NWPU%20VHR-10_dataset_coco/annotations.json"
- ),
- "filename": "annotations.json",
- "md5": "7c76ec50c17a61bb0514050d20f22c08",
- }
-
- def __init__(
- self,
- root: str = "data",
- split: str = "positive",
- transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
- download: bool = False,
- checksum: bool = False,
- ) -> None:
- """Initialize a new VHR-10 dataset instance.
-
- Args:
- root: root directory where dataset can be found
- split: one of "postive" or "negative"
- transforms: a function/transform that takes input sample and its target as
- entry and returns a transformed version
- download: if True, download dataset and store it in the root directory
- checksum: if True, check the MD5 of the downloaded files (may be slow)
-
- Raises:
- AssertionError: if ``split`` argument is invalid
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
- """
- assert split in ["positive", "negative"]
-
- self.root = root
- self.split = split
- self.transforms = transforms
- self.checksum = checksum
-
- if download:
- self._download()
-
- if not self._check_integrity():
- raise RuntimeError(
- "Dataset not found or corrupted. "
- + "You can use download=True to download it"
- )
-
- if split == "positive":
- # Must be installed to parse annotations file
- try:
- from pycocotools.coco import COCO # noqa: F401
- except ImportError:
- raise ImportError(
- "pycocotools is not installed and is required to use this dataset"
- )
-
- self.coco = COCO(
- os.path.join(
- self.root, "NWPU VHR-10 dataset", self.target_meta["filename"]
- )
- )
-
- def __getitem__(self, index: int) -> Dict[str, Any]:
- """Return an index within the dataset.
-
- Args:
- index: index to return
-
- Returns:
- data and label at that index
- """
- id_ = index % len(self) + 1
- sample = {"image": self._load_image(id_), "label": self._load_target(id_)}
-
- if self.transforms is not None:
- sample = self.transforms(sample)
-
- return sample
-
- def __len__(self) -> int:
- """Return the number of data points in the dataset.
-
- Returns:
- length of the dataset
- """
- if self.split == "positive":
- return 650
- else:
- return 150
-
- def _load_image(self, id_: int) -> Tensor:
- """Load a single image.
-
- Args:
- id_: unique ID of the image
-
- Returns:
- the image
- """
- filename = os.path.join(
- self.root,
- "NWPU VHR-10 dataset",
- self.split + " image set",
- f"{id_:03d}.jpg",
- )
- with Image.open(filename) as img:
- array: "np.typing.NDArray[np.int_]" = np.array(img)
- tensor = torch.from_numpy(array)
- # Convert from HxWxC to CxHxW
- tensor = tensor.permute((2, 0, 1))
- return tensor
-
- def _load_target(self, id_: int) -> Dict[str, Any]:
- """Load the annotations for a single image.
-
- Args:
- id_: unique ID of the image
-
- Returns:
- the annotations
- """
- # Images in the "negative" image set have no annotations
- annot = []
- if self.split == "positive":
- annot = self.coco.loadAnns(self.coco.getAnnIds(id_))
-
- target = dict(image_id=id_, annotations=annot)
-
- return target
-
- def _check_integrity(self) -> bool:
- """Check integrity of dataset.
-
- Returns:
- True if dataset files are found and/or MD5s match, else False
- """
- image: bool = check_integrity(
- os.path.join(self.root, self.image_meta["filename"]),
- self.image_meta["md5"] if self.checksum else None,
- )
-
- # Annotations only needed for "positive" image set
- target = True
- if self.split == "positive":
- target = check_integrity(
- os.path.join(
- self.root, "NWPU VHR-10 dataset", self.target_meta["filename"]
- ),
- self.target_meta["md5"] if self.checksum else None,
- )
-
- return image and target
-
- def _download(self) -> None:
- """Download the dataset and extract it."""
- if self._check_integrity():
- print("Files already downloaded and verified")
- return
-
- # Download images
- download_and_extract_archive(
- self.image_meta["url"],
- self.root,
- filename=self.image_meta["filename"],
- md5=self.image_meta["md5"] if self.checksum else None,
- )
-
- # Annotations only needed for "positive" image set
- if self.split == "positive":
- # Download annotations
- download_url(
- self.target_meta["url"],
- os.path.join(self.root, "NWPU VHR-10 dataset"),
- self.target_meta["filename"],
- self.target_meta["md5"] if self.checksum else None,
- )
diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py
new file mode 100644
index 00000000000..ade678ddbf5
--- /dev/null
+++ b/torchgeo/datasets/vhr10.py
@@ -0,0 +1,540 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""NWPU VHR-10 dataset."""
+
+import os
+from typing import Any, Callable, Dict, List, Optional
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from matplotlib import patches
+from PIL import Image
+from skimage.measure import find_contours
+from torch import Tensor
+
+from .geo import NonGeoDataset
+from .utils import check_integrity, download_and_extract_archive, download_url
+
+
+def convert_coco_poly_to_mask(
+ segmentations: List[int], height: int, width: int
+) -> Tensor:
+ """Convert coco polygons to mask tensor.
+
+ Args:
+ segmentations (List[int]): polygon coordinates
+ height (int): image height
+ width (int): image width
+
+ Returns:
+ Tensor: Mask tensor
+ """
+ from pycocotools import mask as coco_mask # noqa: F401
+
+ masks = []
+ for polygons in segmentations:
+ rles = coco_mask.frPyObjects(polygons, height, width)
+ mask = coco_mask.decode(rles)
+ if len(mask.shape) < 3:
+ mask = mask[..., None]
+ mask = torch.as_tensor(mask, dtype=torch.uint8)
+ mask = mask.any(dim=2)
+ masks.append(mask)
+ if masks:
+ masks_tensor = torch.stack(masks, dim=0)
+ else:
+ masks_tensor = torch.zeros((0, height, width), dtype=torch.uint8)
+ return masks_tensor
+
+
+class ConvertCocoAnnotations:
+ """Callable for converting the boxes, masks and labels into tensors.
+
+ This is a modified version of ConvertCocoPolysToMask() from torchvision found in
+ https://github.com/pytorch/vision/blob/main/references/detection/coco_utils.py
+ """
+
+ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+ """Converts MS COCO fields (boxes, masks & labels) from list of ints to tensors.
+
+ Args:
+ sample (Dict[str, Any]): Sample
+
+ Returns:
+ Dict[str, Any]: Processed sample
+ """
+ image = sample["image"]
+ _, h, w = image.size()
+ target = sample["label"]
+
+ image_id = target["image_id"]
+ image_id = torch.tensor([image_id])
+
+ anno = target["annotations"]
+
+ anno = [obj for obj in anno if obj["iscrowd"] == 0]
+
+ bboxes = [obj["bbox"] for obj in anno]
+ # guard against no boxes via resizing
+ boxes = torch.as_tensor(bboxes, dtype=torch.float32).reshape(-1, 4)
+ boxes[:, 2:] += boxes[:, :2]
+ boxes[:, 0::2].clamp_(min=0, max=w)
+ boxes[:, 1::2].clamp_(min=0, max=h)
+
+ categories = [obj["category_id"] for obj in anno]
+ classes = torch.tensor(categories, dtype=torch.int64)
+
+ if "segmentation" in anno[0]:
+ segmentations = [obj["segmentation"] for obj in anno]
+ else:
+ segmentations = []
+ masks = convert_coco_poly_to_mask(segmentations, h, w)
+
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+ boxes = boxes[keep]
+ classes = classes[keep]
+
+ target = {"boxes": boxes, "labels": classes, "image_id": image_id}
+ if masks.nelement() > 0:
+ masks = masks[keep]
+ target["masks"] = masks
+
+ # for conversion to coco api
+ area = torch.tensor([obj["area"] for obj in anno])
+ iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
+ target["area"] = area
+ target["iscrowd"] = iscrowd
+ return {"image": image, "label": target}
+
+
+class VHR10(NonGeoDataset):
+ """NWPU VHR-10 dataset.
+
+ Northwestern Polytechnical University (NWPU) very-high-resolution ten-class (VHR-10)
+ remote sensing image dataset.
+
+ Consists of 800 VHR optical remote sensing images, where 715 color images were
+ acquired from Google Earth with the spatial resolution ranging from 0.5 to 2 m,
+ and 85 pansharpened color infrared (CIR) images were acquired from Vaihingen data
+ with a spatial resolution of 0.08 m.
+
+ The data set is divided into two sets:
+
+ * Positive image set (650 images) which contains at least one target in an image
+ * Negative image set (150 images) does not contain any targets
+
+ The positive image set consists of objects from ten classes:
+
+ 1. Airplanes (757)
+ 2. Ships (302)
+ 3. Storage tanks (655)
+ 4. Baseball diamonds (390)
+ 5. Tennis courts (524)
+ 6. Basketball courts (159)
+ 7. Ground track fields (163)
+ 8. Harbors (224)
+ 9. Bridges (124)
+ 10. Vehicles (477)
+
+ Includes object detection bounding boxes from original paper and instance
+ segmentation masks from follow-up publications. If you use this dataset in your
+ research, please cite the following papers:
+
+ * https://doi.org/10.1016/j.isprsjprs.2014.10.002
+ * https://doi.org/10.1109/IGARSS.2019.8898573
+ * https://doi.org/10.3390/rs12060989
+
+ .. note::
+
+ This dataset requires the following additional libraries to be installed:
+
+ * `pycocotools `_ to load the
+ ``annotations.json`` file for the "positive" image set
+ * `rarfile `_ to extract the dataset,
+ which is stored in a RAR file
+ """
+
+ image_meta = {
+ "url": "https://drive.google.com/file/d/1--foZ3dV5OCsqXQXT84UeKtrAqc5CkAE",
+ "filename": "NWPU VHR-10 dataset.rar",
+ "md5": "d30a7ff99d92123ebb0b3a14d9102081",
+ }
+ target_meta = {
+ "url": (
+ "https://raw.githubusercontent.com/chaozhong2010/VHR-10_dataset_coco/"
+ "master/NWPU%20VHR-10_dataset_coco/annotations.json"
+ ),
+ "filename": "annotations.json",
+ "md5": "7c76ec50c17a61bb0514050d20f22c08",
+ }
+
+ categories = [
+ "background",
+ "airplane",
+ "ships",
+ "storage tank",
+ "baseball diamond",
+ "tennis court",
+ "basketball court",
+ "ground track field",
+ "harbor",
+ "bridge",
+ "vehicle",
+ ]
+
+ def __init__(
+ self,
+ root: str = "data",
+ split: str = "positive",
+ transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+ download: bool = False,
+ checksum: bool = False,
+ ) -> None:
+ """Initialize a new VHR-10 dataset instance.
+
+ Args:
+ root: root directory where dataset can be found
+ split: one of "postive" or "negative"
+ transforms: a function/transform that takes input sample and its target as
+ entry and returns a transformed version
+ download: if True, download dataset and store it in the root directory
+ checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ AssertionError: if ``split`` argument is invalid
+ RuntimeError: if ``download=False`` and data is not found, or checksums
+ don't match
+ """
+ assert split in ["positive", "negative"]
+
+ self.root = root
+ self.split = split
+ self.transforms = transforms
+ self.checksum = checksum
+
+ if download:
+ self._download()
+
+ if not self._check_integrity():
+ raise RuntimeError(
+ "Dataset not found or corrupted. "
+ + "You can use download=True to download it"
+ )
+
+ if split == "positive":
+ # Must be installed to parse annotations file
+ try:
+ from pycocotools.coco import COCO # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "pycocotools is not installed and is required to use this dataset"
+ )
+
+ self.coco = COCO(
+ os.path.join(
+ self.root, "NWPU VHR-10 dataset", self.target_meta["filename"]
+ )
+ )
+
+ self.coco_convert = ConvertCocoAnnotations()
+ self.ids = list(sorted(self.coco.imgs.keys()))
+
+ def __getitem__(self, index: int) -> Dict[str, Any]:
+ """Return an index within the dataset.
+
+ Args:
+ index: index to return
+
+ Returns:
+ data and label at that index
+ """
+ id_ = index % len(self) + 1
+
+ sample: Dict[str, Any] = {
+ "image": self._load_image(id_),
+ "label": self._load_target(id_),
+ }
+
+ if sample["label"]["annotations"]:
+ sample = self.coco_convert(sample)
+ sample["labels"] = sample["label"]["labels"]
+ sample["boxes"] = sample["label"]["boxes"]
+ if "masks" in sample["label"]:
+ sample["masks"] = sample["label"]["masks"]
+ del sample["label"]
+
+ if self.transforms is not None:
+ sample = self.transforms(sample)
+
+ return sample
+
+ def __len__(self) -> int:
+ """Return the number of data points in the dataset.
+
+ Returns:
+ length of the dataset
+ """
+ if self.split == "positive":
+ return len(self.ids)
+ else:
+ return 150
+
+ def _load_image(self, id_: int) -> Tensor:
+ """Load a single image.
+
+ Args:
+ id_: unique ID of the image
+
+ Returns:
+ the image
+ """
+ filename = os.path.join(
+ self.root,
+ "NWPU VHR-10 dataset",
+ self.split + " image set",
+ f"{id_:03d}.jpg",
+ )
+ with Image.open(filename) as img:
+ array: "np.typing.NDArray[np.int_]" = np.array(img)
+ tensor = torch.from_numpy(array)
+ # Convert from HxWxC to CxHxW
+ tensor = tensor.permute((2, 0, 1))
+ return tensor
+
+ def _load_target(self, id_: int) -> Dict[str, Any]:
+ """Load the annotations for a single image.
+
+ Args:
+ id_: unique ID of the image
+
+ Returns:
+ the annotations
+ """
+ # Images in the "negative" image set have no annotations
+ annot = []
+ if self.split == "positive":
+ annot = self.coco.loadAnns(self.coco.getAnnIds(id_ - 1))
+
+ target = dict(image_id=id_, annotations=annot)
+
+ return target
+
+ def _check_integrity(self) -> bool:
+ """Check integrity of dataset.
+
+ Returns:
+ True if dataset files are found and/or MD5s match, else False
+ """
+ image: bool = check_integrity(
+ os.path.join(self.root, self.image_meta["filename"]),
+ self.image_meta["md5"] if self.checksum else None,
+ )
+
+ # Annotations only needed for "positive" image set
+ target = True
+ if self.split == "positive":
+ target = check_integrity(
+ os.path.join(
+ self.root, "NWPU VHR-10 dataset", self.target_meta["filename"]
+ ),
+ self.target_meta["md5"] if self.checksum else None,
+ )
+
+ return image and target
+
+ def _download(self) -> None:
+ """Download the dataset and extract it."""
+ if self._check_integrity():
+ print("Files already downloaded and verified")
+ return
+
+ # Download images
+ download_and_extract_archive(
+ self.image_meta["url"],
+ self.root,
+ filename=self.image_meta["filename"],
+ md5=self.image_meta["md5"] if self.checksum else None,
+ )
+
+ # Annotations only needed for "positive" image set
+ if self.split == "positive":
+ # Download annotations
+ download_url(
+ self.target_meta["url"],
+ os.path.join(self.root, "NWPU VHR-10 dataset"),
+ self.target_meta["filename"],
+ self.target_meta["md5"] if self.checksum else None,
+ )
+
+ def plot(
+ self,
+ sample: Dict[str, Tensor],
+ show_titles: bool = True,
+ suptitle: Optional[str] = None,
+ show_feats: Optional[str] = "both",
+ box_alpha: float = 0.7,
+ mask_alpha: float = 0.7,
+ ) -> plt.Figure:
+ """Plot a sample from the dataset.
+
+ Args:
+ sample: a sample returned by :meth:`__getitem__`
+ suptitle: optional string to use as a suptitle
+ show_titles: flag indicating whether to show titles above each panel
+ show_feats: optional string to pick features to be shown: boxes, masks, both
+ box_alpha: alpha value of box
+ mask_alpha: alpha value of mask
+
+ Returns:
+ a matplotlib Figure with the rendered sample
+
+ .. versionadded:: 0.4
+ """
+ assert show_feats in {"boxes", "masks", "both"}
+
+ if self.split == "negative":
+ plt.imshow(sample["image"].permute(1, 2, 0))
+ axs = plt.gca()
+ axs.axis("off")
+
+ if suptitle is not None:
+ plt.suptitle(suptitle)
+ return plt.gcf()
+
+ image = sample["image"].permute(1, 2, 0).numpy()
+ boxes = sample["boxes"].cpu().numpy()
+ labels = sample["labels"].cpu().numpy()
+ if "masks" in sample:
+ masks = [mask.squeeze().cpu().numpy() for mask in sample["masks"]]
+
+ N_GT = len(boxes)
+
+ ncols = 1
+ show_predictions = "prediction_labels" in sample
+
+ if show_predictions:
+ show_pred_boxes = False
+ show_pred_masks = False
+ prediction_labels = sample["prediction_labels"].numpy()
+ prediction_scores = sample["prediction_scores"].numpy()
+ if "prediction_boxes" in sample:
+ prediction_boxes = sample["prediction_boxes"].numpy()
+ show_pred_boxes = True
+ if "prediction_masks" in sample:
+ prediction_masks = sample["prediction_masks"].numpy()
+ show_pred_masks = True
+
+ N_PRED = len(prediction_labels)
+ ncols += 1
+
+ # Display image
+ fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
+ if not isinstance(axs, np.ndarray):
+ axs = [axs]
+ axs[0].imshow(image)
+ axs[0].axis("off")
+
+ cm = plt.get_cmap("gist_rainbow")
+ for i in range(N_GT):
+ class_num = labels[i]
+ color = cm(class_num / len(self.categories))
+
+ # Add bounding boxes
+ x1, y1, x2, y2 = boxes[i]
+ if show_feats in {"boxes", "both"}:
+ p = patches.Rectangle(
+ (x1, y1),
+ x2 - x1,
+ y2 - y1,
+ linewidth=2,
+ alpha=box_alpha,
+ linestyle="dashed",
+ edgecolor=color,
+ facecolor="none",
+ )
+ axs[0].add_patch(p)
+
+ # Add labels
+ label = self.categories[class_num]
+ caption = label
+ axs[0].text(
+ x1, y1 - 8, caption, color="white", size=11, backgroundcolor="none"
+ )
+
+ # Add masks
+ if show_feats in {"masks", "both"} and "masks" in sample:
+ mask = masks[i]
+ contours = find_contours(mask, 0.5)
+ for verts in contours:
+ verts = np.fliplr(verts)
+ p = patches.Polygon(
+ verts, facecolor=color, alpha=mask_alpha, edgecolor="white"
+ )
+ axs[0].add_patch(p)
+
+ if show_titles:
+ axs[0].set_title("Ground Truth")
+
+ if show_predictions:
+ axs[1].imshow(image)
+ axs[1].axis("off")
+ for i in range(N_PRED):
+ score = prediction_scores[i]
+ if score < 0.5:
+ continue
+
+ # TODO: Check scores
+
+ class_num = prediction_labels[i]
+ color = cm(class_num / len(self.categories))
+
+ if show_pred_boxes:
+ # Add bounding boxes
+ x1, y1, x2, y2 = prediction_boxes[i]
+ p = patches.Rectangle(
+ (x1, y1),
+ x2 - x1,
+ y2 - y1,
+ linewidth=2,
+ alpha=box_alpha,
+ linestyle="dashed",
+ edgecolor=color,
+ facecolor="none",
+ )
+ axs[1].add_patch(p)
+
+ # Add labels
+ label = self.categories[class_num]
+ caption = f"{label} {score:.3f}"
+ axs[1].text(
+ x1,
+ y1 - 8,
+ caption,
+ color="white",
+ size=11,
+ backgroundcolor="none",
+ )
+
+ # TODO: Labels are dependent on boxes being shown
+ # Add masks
+ if show_pred_masks:
+ mask = prediction_masks[i]
+ contours = find_contours(mask, 0.5)
+ for verts in contours:
+ verts = np.fliplr(verts)
+ p = patches.Polygon(
+ verts, facecolor=color, alpha=mask_alpha, edgecolor="white"
+ )
+ axs[1].add_patch(p)
+
+ if show_titles:
+ axs[1].set_title("Prediction")
+
+ plt.tight_layout()
+
+ if suptitle is not None:
+ plt.suptitle(suptitle)
+
+ return fig
diff --git a/train.py b/train.py
index 3f3fd7fb77f..8b4f1d0e153 100755
--- a/train.py
+++ b/train.py
@@ -29,6 +29,7 @@
So2SatDataModule,
TropicalCycloneDataModule,
UCMercedDataModule,
+ VHR10DataModule,
)
from torchgeo.trainers import (
BYOLTask,
@@ -58,6 +59,7 @@
"sen12ms": (SemanticSegmentationTask, SEN12MSDataModule),
"so2sat": (ClassificationTask, So2SatDataModule),
"ucmerced": (ClassificationTask, UCMercedDataModule),
+ "vhr10": (ObjectDetectionTask, VHR10DataModule),
}