Skip to content

Commit

Permalink
Add ReforesTree dataset (microsoft#582)
Browse files Browse the repository at this point in the history
* add ReforesTree dataset

* fix failing test

* suggested changes

* Update download URL

* Change zipfile name

* Minor fixes

* Remove f-string

* Fix dtype, remove unnecessary conversion

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
  • Loading branch information
3 people authored Jul 9, 2022
1 parent 632d580 commit 8dc01f8
Show file tree
Hide file tree
Showing 10 changed files with 483 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,11 @@ Potsdam

.. autoclass:: Potsdam2D

ReforesTree
^^^^^^^^^^^

.. autoclass:: ReforesTree

RESISC45
^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
`OSCD`_,CD,Sentinel-2,24,2,"40--1,180",60,MSI
`PatternNet`_,C,Google Earth,"30,400",38,256x256,0.06--5,RGB
`Potsdam`_,S,Aerial,38,6,"6,000x6,000",0.05,MSI
`ReforesTree`_,"OD, R",Aerial,100,6,"4,000x4,000",0.02,RGB
`RESISC45`_,C,Google Earth,"31,500",45,256x256,0.2--30,RGB
`Seasonal Contrast`_,T,Sentinel-2,100K--1M,-,264x264,10,MSI
`SEN12MS`_,S,"Sentinel-1/2, MODIS","180,662",33,256x256,10,"SAR, MSI"
Expand Down
75 changes: 75 additions & 0 deletions tests/data/reforestree/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import csv
import hashlib
import os
import shutil
from typing import List

import numpy as np
from PIL import Image

# Side length, in pixels, of every synthetic test image.
SIZE = 32

# Seed NumPy's global RNG so regenerating the fixtures yields the same pixels.
np.random.seed(0)

# Files to generate, expressed relative to the dataset root directory.
PATHS = {
    "images": [
        "tiles/Site1/Site1_RGB_0_0_0_4000_4000.png",
        "tiles/Site2/Site2_RGB_0_0_0_4000_4000.png",
    ],
    "annotation": "mapping/final_dataset.csv",
}


def create_annotation(path: str, img_paths: List[str]) -> None:
    """Write the fake bounding-box annotation CSV for the test images.

    Every image contributes two rows: a "banana" box spanning (0, 0) to
    (SIZE / 2, SIZE / 2) and a "cacao" box spanning (SIZE / 2, SIZE / 2) to
    (SIZE, SIZE). Both rows carry the same AGB value of 6.75.

    Args:
        path: location of the CSV file to create (overwritten if present)
        img_paths: image paths; only their basenames are written to the file
    """
    header = ["img_path", "xmin", "ymin", "xmax", "ymax", "group", "AGB"]
    # True division on purpose: the half-size is written as a float.
    half = SIZE / 2

    rows = []
    for img_path in img_paths:
        filename = os.path.basename(img_path)
        rows.extend(
            [
                [filename, 0, 0, half, half, "banana", 6.75],
                [filename, half, half, SIZE, SIZE, "cacao", 6.75],
            ]
        )

    # newline="" lets the csv module control line endings itself.
    with open(path, "w", newline="") as f:
        csv.writer(f).writerows([header, *rows])


def create_img(path: str) -> None:
    """Save a random SIZE x SIZE RGB image to ``path``.

    Pixel values are drawn from NumPy's global random generator, so the
    result depends on that generator's current state.

    Args:
        path: file to write; it is passed straight to ``Image.save``
    """
    # Scale uniform [0, 1) samples to [0, 255) and truncate to 8-bit integers.
    pixels = (np.random.rand(SIZE, SIZE, 3) * 255).astype("uint8")
    rgb = Image.fromarray(pixels).convert("RGB")
    rgb.save(path)


if __name__ == "__main__":
    data_root = "reforesTree"

    # Start from a clean slate so stale files never end up in the archive.
    if os.path.isdir(data_root):
        shutil.rmtree(data_root)

    # Generate the fake image tiles.
    for rel_path in PATHS["images"]:
        img_file = os.path.join(data_root, rel_path)
        os.makedirs(os.path.dirname(img_file), exist_ok=True)
        create_img(img_file)

    # Generate the annotation CSV describing the boxes in those tiles.
    annotation_file = os.path.join(data_root, PATHS["annotation"])
    os.makedirs(os.path.dirname(annotation_file), exist_ok=True)
    create_annotation(annotation_file, PATHS["images"])

    # Zip the directory contents into "<data_root>.zip" next to it.
    shutil.make_archive(data_root, "zip", data_root)

    # Print the archive's MD5 checksum.
    with open(f"{data_root}.zip", "rb") as f:
        digest = hashlib.md5(f.read()).hexdigest()
    print(f"{data_root}: {digest}")
Binary file added tests/data/reforestree/reforesTree.zip
Binary file not shown.
5 changes: 5 additions & 0 deletions tests/data/reforestree/reforesTree/mapping/final_dataset.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
img_path,xmin,ymin,xmax,ymax,group,AGB
Site1_RGB_0_0_0_4000_4000.png,0,0,16.0,16.0,banana,6.75
Site1_RGB_0_0_0_4000_4000.png,16.0,16.0,32,32,cacao,6.75
Site2_RGB_0_0_0_4000_4000.png,0,0,16.0,16.0,banana,6.75
Site2_RGB_0_0_0_4000_4000.png,16.0,16.0,32,32,cacao,6.75
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
104 changes: 104 additions & 0 deletions tests/datasets/test_reforestree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datasets import ReforesTree


def download_url(url: str, root: str, *args: str) -> None:
    """Copy a local file into ``root`` in place of a real download.

    Args:
        url: path of the local file to copy
        root: destination passed to ``shutil.copy``
        *args: additional positional arguments, accepted and ignored
    """
    shutil.copy(src=url, dst=root)


class TestReforesTree:
    """Tests for the ReforesTree dataset, run against the small fixture archive."""

    @pytest.fixture
    def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ReforesTree:
        """Return a ReforesTree instance built from the local test archive."""
        # Replace the downloader with a plain local file copy.
        monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url)

        # Point the dataset class at the fixture zip and its checksum.
        archive = os.path.join("tests", "data", "reforestree", "reforesTree.zip")
        monkeypatch.setattr(ReforesTree, "url", archive)
        monkeypatch.setattr(ReforesTree, "md5", "387e04dbbb0aa803f72bd6d774409648")

        return ReforesTree(
            root=str(tmp_path),
            transforms=nn.Identity(),
            download=True,
            checksum=True,
        )

    def test_already_downloaded(self, dataset: ReforesTree) -> None:
        """Constructing again on a populated root with download=True succeeds."""
        ReforesTree(root=dataset.root, download=True)

    def test_getitem(self, dataset: ReforesTree) -> None:
        """A sample is a dict of tensors with a 3-channel image and two boxes."""
        sample = dataset[0]
        assert isinstance(sample, dict)
        for key in ("image", "label", "boxes", "agb"):
            assert isinstance(sample[key], torch.Tensor)

        image = sample["image"]
        assert image.shape[0] == 3
        assert image.ndim == 3
        assert len(sample["boxes"]) == 2

    @pytest.fixture
    def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
        """Make every ``import pandas`` raise ImportError."""
        real_import = builtins.__import__
        blocked = "pandas"

        def fake_import(name: str, *args: Any, **kwargs: Any) -> Any:
            # Delegate everything except the blocked package to the real hook.
            if name != blocked:
                return real_import(name, *args, **kwargs)
            raise ImportError()

        monkeypatch.setattr(builtins, "__import__", fake_import)

    def test_mock_missing_module(
        self, dataset: ReforesTree, mock_missing_module: None
    ) -> None:
        """Without pandas, construction fails with the expected message."""
        expected = "pandas is not installed and is required to use this dataset"
        with pytest.raises(ImportError, match=expected):
            ReforesTree(root=dataset.root)

    def test_len(self, dataset: ReforesTree) -> None:
        """The fixture dataset reports a length of two."""
        assert len(dataset) == 2

    def test_not_extracted(self, tmp_path: Path) -> None:
        """A root holding only the zip archive can still be loaded."""
        archive = os.path.join("tests", "data", "reforestree", "reforesTree.zip")
        shutil.copy(archive, tmp_path)
        ReforesTree(root=str(tmp_path))

    def test_corrupted(self, tmp_path: Path) -> None:
        """An archive with bogus contents is rejected when checksum=True."""
        bad_archive = os.path.join(tmp_path, "reforesTree.zip")
        with open(bad_archive, "w") as f:
            f.write("bad")

        with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
            ReforesTree(root=str(tmp_path), checksum=True)

    def test_not_found(self, tmp_path: Path) -> None:
        """An empty root without download raises a 'not found' error."""
        with pytest.raises(RuntimeError, match="Dataset not found in"):
            ReforesTree(str(tmp_path))

    def test_plot(self, dataset: ReforesTree) -> None:
        """Plotting a sample with a suptitle runs without error."""
        sample = dataset[0].copy()
        dataset.plot(sample, suptitle="Test")
        plt.close()

    def test_plot_prediction(self, dataset: ReforesTree) -> None:
        """Plotting also works when the sample carries predicted boxes."""
        sample = dataset[0].copy()
        sample["prediction_boxes"] = sample["boxes"].clone()
        dataset.plot(sample, suptitle="Prediction")
        plt.close()
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from .oscd import OSCD
from .patternnet import PatternNet
from .potsdam import Potsdam2D
from .reforestree import ReforesTree
from .resisc45 import RESISC45
from .seco import SeasonalContrastS2
from .sen12ms import SEN12MS
Expand Down Expand Up @@ -167,6 +168,7 @@
"PatternNet",
"Potsdam2D",
"RESISC45",
"ReforesTree",
"SeasonalContrastS2",
"SEN12MS",
"So2Sat",
Expand Down
Loading

0 comments on commit 8dc01f8

Please sign in to comment.