diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 41a4960e354..12ca5eaef08 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -244,6 +244,11 @@ Potsdam .. autoclass:: Potsdam2D +ReforesTree +^^^^^^^^^^^ + +.. autoclass:: ReforesTree + RESISC45 ^^^^^^^^ diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv index 7fca8bebec7..c4234270458 100644 --- a/docs/api/non_geo_datasets.csv +++ b/docs/api/non_geo_datasets.csv @@ -21,6 +21,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands `OSCD`_,CD,Sentinel-2,24,2,"40--1,180",60,MSI `PatternNet`_,C,Google Earth,"30,400",38,256x256,0.06--5,RGB `Potsdam`_,S,Aerial,38,6,"6,000x6,000",0.05,MSI +`ReforesTree`_,"OD, R",Aerial,100,6,"4,000x4,000",0.02,RGB `RESISC45`_,C,Google Earth,"31,500",45,256x256,0.2--30,RGB `Seasonal Contrast`_,T,Sentinel-2,100K--1M,,264x264,10,MSI `SEN12MS`_,S,"Sentinel-1/2, MODIS","180,662",33,256x256,10,"SAR, MSI" diff --git a/tests/data/reforestree/data.py b/tests/data/reforestree/data.py new file mode 100644 index 00000000000..27573cb6191 --- /dev/null +++ b/tests/data/reforestree/data.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +import csv +import hashlib +import os +import shutil +from typing import List + +import numpy as np +from PIL import Image + +SIZE = 32 + +np.random.seed(0) + +PATHS = { + "images": [ + "tiles/Site1/Site1_RGB_0_0_0_4000_4000.png", + "tiles/Site2/Site2_RGB_0_0_0_4000_4000.png", + ], + "annotation": "mapping/final_dataset.csv", +} + + +def create_annotation(path: str, img_paths: List[str]) -> None: + cols = ["img_path", "xmin", "ymin", "xmax", "ymax", "group", "AGB"] + data = [] + for img_path in img_paths: + data.append( + [os.path.basename(img_path), 0, 0, SIZE / 2, SIZE / 2, "banana", 6.75] + ) + data.append( + [os.path.basename(img_path), SIZE / 2, SIZE / 2, SIZE, SIZE, "cacao", 6.75] + ) + + with open(path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(cols) + writer.writerows(data) + + +def create_img(path: str) -> None: + Z = np.random.rand(SIZE, SIZE, 3) * 255 + img = Image.fromarray(Z.astype("uint8")).convert("RGB") + img.save(path) + + +if __name__ == "__main__": + data_root = "reforesTree" + + # remove old data + if os.path.isdir(data_root): + shutil.rmtree(data_root) + + # create imagery + for path in PATHS["images"]: + os.makedirs(os.path.join(data_root, os.path.dirname(path)), exist_ok=True) + create_img(os.path.join(data_root, path)) + + # create annotations + os.makedirs( + os.path.join(data_root, os.path.dirname(PATHS["annotation"])), exist_ok=True + ) + create_annotation(os.path.join(data_root, PATHS["annotation"]), PATHS["images"]) + + # compress data + shutil.make_archive(data_root, "zip", data_root) + + # Compute checksums + with open(data_root + ".zip", "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(f"{data_root}: {md5}") diff --git a/tests/data/reforestree/reforesTree.zip b/tests/data/reforestree/reforesTree.zip new file mode 100644 index 00000000000..d1081a06a56 Binary files /dev/null and b/tests/data/reforestree/reforesTree.zip differ diff --git 
a/tests/data/reforestree/reforesTree/mapping/final_dataset.csv b/tests/data/reforestree/reforesTree/mapping/final_dataset.csv new file mode 100644 index 00000000000..9c71d73563a --- /dev/null +++ b/tests/data/reforestree/reforesTree/mapping/final_dataset.csv @@ -0,0 +1,5 @@ +img_path,xmin,ymin,xmax,ymax,group,AGB +Site1_RGB_0_0_0_4000_4000.png,0,0,16.0,16.0,banana,6.75 +Site1_RGB_0_0_0_4000_4000.png,16.0,16.0,32,32,cacao,6.75 +Site2_RGB_0_0_0_4000_4000.png,0,0,16.0,16.0,banana,6.75 +Site2_RGB_0_0_0_4000_4000.png,16.0,16.0,32,32,cacao,6.75 diff --git a/tests/data/reforestree/reforesTree/tiles/Site1/Site1_RGB_0_0_0_4000_4000.png b/tests/data/reforestree/reforesTree/tiles/Site1/Site1_RGB_0_0_0_4000_4000.png new file mode 100644 index 00000000000..95e37237fe8 Binary files /dev/null and b/tests/data/reforestree/reforesTree/tiles/Site1/Site1_RGB_0_0_0_4000_4000.png differ diff --git a/tests/data/reforestree/reforesTree/tiles/Site2/Site2_RGB_0_0_0_4000_4000.png b/tests/data/reforestree/reforesTree/tiles/Site2/Site2_RGB_0_0_0_4000_4000.png new file mode 100644 index 00000000000..69a0632fc1a Binary files /dev/null and b/tests/data/reforestree/reforesTree/tiles/Site2/Site2_RGB_0_0_0_4000_4000.png differ diff --git a/tests/datasets/test_reforestree.py b/tests/datasets/test_reforestree.py new file mode 100644 index 00000000000..1337cfb18c3 --- /dev/null +++ b/tests/datasets/test_reforestree.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +import builtins +import os +import shutil +from pathlib import Path +from typing import Any + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.monkeypatch import MonkeyPatch + +import torchgeo.datasets.utils +from torchgeo.datasets import ReforesTree + + +def download_url(url: str, root: str, *args: str) -> None: + shutil.copy(url, root) + + +class TestReforesTree: + @pytest.fixture + def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ReforesTree: + monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) + data_dir = os.path.join("tests", "data", "reforestree") + + url = os.path.join(data_dir, "reforesTree.zip") + + md5 = "387e04dbbb0aa803f72bd6d774409648" + + monkeypatch.setattr(ReforesTree, "url", url) + monkeypatch.setattr(ReforesTree, "md5", md5) + root = str(tmp_path) + transforms = nn.Identity() + return ReforesTree( + root=root, transforms=transforms, download=True, checksum=True + ) + + def test_already_downloaded(self, dataset: ReforesTree) -> None: + ReforesTree(root=dataset.root, download=True) + + def test_getitem(self, dataset: ReforesTree) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["label"], torch.Tensor) + assert isinstance(x["boxes"], torch.Tensor) + assert isinstance(x["agb"], torch.Tensor) + assert x["image"].shape[0] == 3 + assert x["image"].ndim == 3 + assert len(x["boxes"]) == 2 + + @pytest.fixture + def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: + import_orig = builtins.__import__ + package = "pandas" + + def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: + if name == package: + raise ImportError() + return import_orig(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", mocked_import) + + def test_mock_missing_module( + self, dataset: ReforesTree, mock_missing_module: None + ) -> None: + with pytest.raises( + ImportError, + 
match="pandas is not installed and is required to use this dataset", + ): + ReforesTree(root=dataset.root) + + def test_len(self, dataset: ReforesTree) -> None: + assert len(dataset) == 2 + + def test_not_extracted(self, tmp_path: Path) -> None: + url = os.path.join("tests", "data", "reforestree", "reforesTree.zip") + shutil.copy(url, tmp_path) + ReforesTree(root=str(tmp_path)) + + def test_corrupted(self, tmp_path: Path) -> None: + with open(os.path.join(tmp_path, "reforesTree.zip"), "w") as f: + f.write("bad") + with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + ReforesTree(root=str(tmp_path), checksum=True) + + def test_not_found(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found in"): + ReforesTree(str(tmp_path)) + + def test_plot(self, dataset: ReforesTree) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + + def test_plot_prediction(self, dataset: ReforesTree) -> None: + x = dataset[0].copy() + x["prediction_boxes"] = x["boxes"].clone() + dataset.plot(x, suptitle="Prediction") + plt.close() diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index e7bbb25b9b8..13196988787 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -75,6 +75,7 @@ from .oscd import OSCD from .patternnet import PatternNet from .potsdam import Potsdam2D +from .reforestree import ReforesTree from .resisc45 import RESISC45 from .seco import SeasonalContrastS2 from .sen12ms import SEN12MS @@ -167,6 +168,7 @@ "PatternNet", "Potsdam2D", "RESISC45", + "ReforesTree", "SeasonalContrastS2", "SEN12MS", "So2Sat", diff --git a/torchgeo/datasets/reforestree.py b/torchgeo/datasets/reforestree.py new file mode 100644 index 00000000000..7ab65b49bc4 --- /dev/null +++ b/torchgeo/datasets/reforestree.py @@ -0,0 +1,291 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
"""ReforesTree dataset."""

import glob
import os
from typing import Callable, Dict, List, Optional, Tuple

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch import Tensor

from .geo import VisionDataset
from .utils import check_integrity, download_and_extract_archive, extract_archive


class ReforesTree(VisionDataset):
    """ReforesTree dataset.

    The `ReforesTree <https://zenodo.org/record/6813783>`__
    dataset contains drone imagery that can be used for tree crown detection,
    tree species classification and Aboveground Biomass (AGB) estimation.

    Dataset features:

    * 100 high resolution RGB drone images at 2 cm/pixel of size 4,000 x 4,000 px
    * more than 4,600 tree crown box annotations
    * tree crown matched with field measurements of diameter at breast height (DBH),
      and computed AGB and carbon values

    Dataset format:

    * images are three-channel pngs
    * annotations are stored in a single csv file

    Dataset Classes:

    0. other
    1. banana
    2. cacao
    3. citrus
    4. fruit
    5. timber

    If you use this dataset in your research, please cite the following paper:

    * https://arxiv.org/abs/2201.11192

    .. versionadded:: 0.3
    """

    classes = ["other", "banana", "cacao", "citrus", "fruit", "timber"]
    url = "https://zenodo.org/record/6813783/files/reforesTree.zip?download=1"

    md5 = "f6a4a1d8207aeaa5fbab7b21b683a302"
    zipfilename = "reforesTree.zip"

    def __init__(
        self,
        root: str = "data",
        transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new ReforesTree dataset instance.

        Args:
            root: root directory where dataset can be found
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            ImportError: if pandas is not installed
            RuntimeError: if ``download=False`` and data is not found, or checksums
                don't match
        """
        self.root = root
        self.transforms = transforms
        self.checksum = checksum
        self.download = download

        self._verify()

        # pandas is an optional dependency, so it is imported lazily here.
        try:
            import pandas as pd  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "pandas is not installed and is required to use this dataset"
            ) from err

        self.files = self._load_files(self.root)

        # One CSV holds the annotations of every tile; rows are matched to a
        # tile through the ``img_path`` column (see ``_load_target``).
        self.annot_df = pd.read_csv(os.path.join(root, "mapping", "final_dataset.csv"))

        self.class2idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)}

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return the sample at the given index.

        Args:
            index: index to return

        Returns:
            dict with keys ``"image"``, ``"boxes"``, ``"label"`` and ``"agb"``
            (after ``transforms`` has been applied, if one was given)
        """
        filepath = self.files[index]

        image = self._load_image(filepath)

        boxes, labels, agb = self._load_target(filepath)

        sample = {"image": image, "boxes": boxes, "label": labels, "agb": agb}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.files)

    def _load_files(self, root: str) -> List[str]:
        """Return the paths of the image tiles in the dataset.

        Args:
            root: root dir of dataset

        Returns:
            sorted list of paths to every ``*.png`` found below ``root/tiles``
        """
        # ``**`` only spans directory levels when ``recursive=True`` is passed;
        # without it glob silently treats ``**`` like a single ``*``.
        pattern = os.path.join(root, "tiles", "**", "*.png")
        return sorted(glob.glob(pattern, recursive=True))

    def _load_image(self, path: str) -> Tensor:
        """Load a single image.

        Args:
            path: path to the image

        Returns:
            the image as a CxHxW tensor with three (RGB) channels
        """
        with Image.open(path) as img:
            # Force three channels: a grayscale PNG would produce a 2-D array
            # and break the permute below, and a palette/RGBA PNG would yield
            # 1 or 4 channels instead of the documented 3.
            array: "np.typing.NDArray[np.uint8]" = np.array(img.convert("RGB"))
            tensor = torch.from_numpy(array)
            # Convert from HxWxC to CxHxW
            tensor = tensor.permute((2, 0, 1))
            return tensor

    def _load_target(self, filepath: str) -> Tuple[Tensor, ...]:
        """Load boxes, labels and AGB values for a single image.

        Args:
            filepath: image tile filepath

        Returns:
            tuple ``(boxes, labels, agb)`` where ``boxes`` has shape (N, 4) in
            ``xmin, ymin, xmax, ymax`` column order, and ``labels`` and ``agb``
            have shape (N,); all three are float32 tensors
        """
        # Annotation rows reference tiles by file name only.
        tile_df = self.annot_df[self.annot_df["img_path"] == os.path.basename(filepath)]

        # reshape keeps the (N, 4) layout even when the tile has no annotation
        # rows; building the tensor from an empty list alone gives shape (0,).
        boxes = torch.tensor(
            tile_df[["xmin", "ymin", "xmax", "ymax"]].values.tolist(),
            dtype=torch.float32,
        ).reshape(-1, 4)
        # NOTE(review): labels are kept as float32 to preserve the existing
        # sample contract; detection losses usually expect int64 - confirm
        # with downstream users before changing the dtype.
        labels = torch.tensor(
            [self.class2idx[label] for label in tile_df["group"].tolist()],
            dtype=torch.float32,
        )
        agb = torch.tensor(tile_df["AGB"].tolist(), dtype=torch.float32)

        return boxes, labels, agb

    def _verify(self) -> None:
        """Checks the integrity of the dataset structure.

        Looks for the extracted ``tiles`` and ``mapping`` directories first,
        then for the zip archive in ``root`` (which is extracted), and finally
        downloads the archive if ``download=True``.

        Raises:
            RuntimeError: if dataset is not found in root or is corrupted
        """
        # Already extracted: nothing to do.
        required = (os.path.join(self.root, name) for name in ("tiles", "mapping"))
        if all(os.path.exists(path) for path in required):
            return

        # Archive present but not extracted yet.
        filepath = os.path.join(self.root, self.zipfilename)
        if os.path.isfile(filepath):
            if self.checksum and not check_integrity(filepath, self.md5):
                raise RuntimeError("Dataset found, but corrupted.")
            extract_archive(filepath)
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise RuntimeError(
                f"Dataset not found in `root={self.root}` and `download=False`, "
                "either specify a different `root` directory or use `download=True` "
                "to automatically download the dataset."
            )

        # else download the dataset
        self._download()

    def _download(self) -> None:
        """Download the dataset and extract it.

        Raises:
            AssertionError: if the checksum does not match
        """
        download_and_extract_archive(
            self.url,
            self.root,
            filename=self.zipfilename,
            md5=self.md5 if self.checksum else None,
        )

    @staticmethod
    def _draw_boxes(ax: plt.Axes, boxes: Tensor) -> None:
        """Draw ``xmin, ymin, xmax, ymax`` boxes as red outlines on an axis.

        Args:
            ax: matplotlib axis to draw on
            boxes: tensor of boxes, one row per box
        """
        for xmin, ymin, xmax, ymax in boxes.numpy():
            ax.add_patch(
                patches.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    linewidth=1,
                    edgecolor="r",
                    facecolor="none",
                )
            )

    def plot(
        self,
        sample: Dict[str, Tensor],
        show_titles: bool = True,
        suptitle: Optional[str] = None,
    ) -> plt.Figure:
        """Plot a sample from the dataset.

        A second panel with the predicted boxes is added when the sample
        contains a ``"prediction_boxes"`` key.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        # Convert from CxHxW to HxWxC for imshow
        image = sample["image"].permute((1, 2, 0)).numpy()
        showing_predictions = "prediction_boxes" in sample
        ncols = 2 if showing_predictions else 1

        fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
        if not showing_predictions:
            # plt.subplots returns a bare Axes (not an array) for a 1x1 grid
            axs = [axs]

        axs[0].imshow(image)
        axs[0].axis("off")
        self._draw_boxes(axs[0], sample["boxes"])

        if show_titles:
            axs[0].set_title("Ground Truth")

        if showing_predictions:
            axs[1].imshow(image)
            axs[1].axis("off")
            self._draw_boxes(axs[1], sample["prediction_boxes"])

            if show_titles:
                axs[1].set_title("Predictions")

        if suptitle is not None:
            # Set the title on this figure explicitly instead of relying on
            # pyplot's notion of the "current" figure.
            fig.suptitle(suptitle)

        return fig