Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ReforesTree dataset #582

Merged
merged 10 commits into from
Jul 9, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,11 @@ Potsdam

.. autoclass:: Potsdam2D

ReforesTree
^^^^^^^^^^^

.. autoclass:: ReforesTree
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved

RESISC45
^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
`OSCD`_,CD,Sentinel-2,24,2,"40--1,180",60,MSI
`PatternNet`_,C,Google Earth,"30,400",38,256x256,0.06--5,RGB
`Potsdam`_,S,Aerial,38,6,"6,000x6,000",0.05,MSI
`ReforesTree`_,"OD, R",Aerial,100,,"4,000x4,000",0.02,RGB
`RESISC45`_,C,Google Earth,"31,500",45,256x256,0.2--30,RGB
`Seasonal Contrast`_,T,Sentinel-2,100K--1M,,264x264,10,MSI
`SEN12MS`_,S,"Sentinel-1/2, MODIS","180,662",33,256x256,10,"SAR, MSI"
Expand Down
75 changes: 75 additions & 0 deletions tests/data/reforestree/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import csv
import hashlib
import os
import shutil
from typing import List

import numpy as np
from PIL import Image

# Side length, in pixels, of the square fake images written by create_img;
# also used as the extent of the bounding boxes in the annotation file.
SIZE = 32

# Seed NumPy's global RNG so the random pixel values are the same on every run.
np.random.seed(0)

# Paths, relative to the data root, of the files this script generates:
# the fake image tiles and the annotation CSV that describes them.
PATHS = {
"images": [
"tiles/Site1/Site1_RGB_0_0_0_4000_4000.png",
"tiles/Site2/Site2_RGB_0_0_0_4000_4000.png",
],
"annotation": "mapping/final_dataset.csv",
}


def create_annotation(path: str, img_paths: List[str]) -> None:
    """Write a fake annotation CSV containing two boxes per image.

    Each image contributes one "banana" row covering the top-left quadrant
    and one "cacao" row covering the bottom-right quadrant, both with an
    AGB value of 6.75. Images are identified by file name only.

    Args:
        path: location of the CSV file to write
        img_paths: image paths whose base names are listed in the file
    """
    header = ["img_path", "xmin", "ymin", "xmax", "ymax", "group", "AGB"]
    half = SIZE / 2

    rows = []
    for img_path in img_paths:
        name = os.path.basename(img_path)
        rows.extend(
            [
                [name, 0, 0, half, half, "banana", 6.75],
                [name, half, half, SIZE, SIZE, "cacao", 6.75],
            ]
        )

    # newline="" lets the csv module control line endings itself
    with open(path, "w", newline="") as csv_file:
        csv.writer(csv_file).writerows([header, *rows])


def create_img(path: str) -> None:
    """Save a SIZE x SIZE RGB image of random pixel values.

    Args:
        path: location of the image file to write
    """
    # Uniform samples in [0, 1) scaled to the 8-bit range, then truncated
    pixels = (np.random.rand(SIZE, SIZE, 3) * 255).astype("uint8")
    Image.fromarray(pixels).convert("RGB").save(path)


if __name__ == "__main__":
    data_root = "reforesTree"

    # Start from scratch so stale files never end up in the archive
    if os.path.isdir(data_root):
        shutil.rmtree(data_root)

    # Generate one fake image per configured tile path
    for rel_path in PATHS["images"]:
        img_file = os.path.join(data_root, rel_path)
        os.makedirs(os.path.dirname(img_file), exist_ok=True)
        create_img(img_file)

    # Generate the annotation file describing those images
    annotation_file = os.path.join(data_root, PATHS["annotation"])
    os.makedirs(os.path.dirname(annotation_file), exist_ok=True)
    create_annotation(annotation_file, PATHS["images"])

    # Zip the contents of the data directory into <data_root>.zip
    shutil.make_archive(base_name=data_root, format="zip", root_dir=data_root)

    # Print the archive checksum
    with open(f"{data_root}.zip", "rb") as archive:
        md5 = hashlib.md5(archive.read()).hexdigest()
        print(f"{data_root}: {md5}")
Binary file added tests/data/reforestree/reforesTree.zip
Binary file not shown.
5 changes: 5 additions & 0 deletions tests/data/reforestree/reforesTree/mapping/final_dataset.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
img_path,xmin,ymin,xmax,ymax,group,AGB
Site1_RGB_0_0_0_4000_4000.png,0,0,16.0,16.0,banana,6.75
Site1_RGB_0_0_0_4000_4000.png,16.0,16.0,32,32,cacao,6.75
Site2_RGB_0_0_0_4000_4000.png,0,0,16.0,16.0,banana,6.75
Site2_RGB_0_0_0_4000_4000.png,16.0,16.0,32,32,cacao,6.75
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
104 changes: 104 additions & 0 deletions tests/datasets/test_reforestree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datasets import ReforesTree


def download_url(url: str, root: str, *args: str) -> None:
    """Stand-in downloader that copies a local file instead of fetching it.

    Args:
        url: path of the local file to copy
        root: destination passed to ``shutil.copy``
        *args: extra positional arguments, accepted and ignored
    """
    shutil.copy(src=url, dst=root)


class TestReforesTree:
    """Tests for the ReforesTree dataset using the small archive in tests/data."""

    @pytest.fixture
    def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ReforesTree:
        """Build a dataset in ``tmp_path`` from the local test archive."""
        # Swap the downloader for a plain file copy
        monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url)

        # Point the dataset at the local archive and its checksum
        archive = os.path.join("tests", "data", "reforestree", "reforesTree.zip")
        monkeypatch.setattr(ReforesTree, "url", archive)
        monkeypatch.setattr(ReforesTree, "md5", "387e04dbbb0aa803f72bd6d774409648")

        return ReforesTree(
            root=str(tmp_path),
            transforms=nn.Identity(),
            download=True,
            checksum=True,
        )

    def test_already_downloaded(self, dataset: ReforesTree) -> None:
        """Constructing again on an already populated root succeeds."""
        ReforesTree(root=dataset.root, download=True)

    def test_getitem(self, dataset: ReforesTree) -> None:
        """A sample is a dict of tensors with a 3-channel image and 2 boxes."""
        sample = dataset[0]
        assert isinstance(sample, dict)
        for key in ("image", "label", "boxes", "agb"):
            assert isinstance(sample[key], torch.Tensor)

        image = sample["image"]
        assert image.shape[0] == 3
        assert image.ndim == 3
        assert len(sample["boxes"]) == 2

    @pytest.fixture
    def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
        """Make ``import pandas`` raise ImportError for the duration of a test."""
        real_import = builtins.__import__
        blocked = "pandas"

        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name != blocked:
                return real_import(name, *args, **kwargs)
            raise ImportError()

        monkeypatch.setattr(builtins, "__import__", mocked_import)

    def test_mock_missing_module(
        self, dataset: ReforesTree, mock_missing_module: None
    ) -> None:
        """A missing pandas import is reported with a descriptive ImportError."""
        expected = "pandas is not installed and is required to use this dataset"
        with pytest.raises(ImportError, match=expected):
            ReforesTree(root=dataset.root)

    def test_len(self, dataset: ReforesTree) -> None:
        """The test archive yields two samples."""
        assert len(dataset) == 2

    def test_not_extracted(self, tmp_path: Path) -> None:
        """An archive that is present but not extracted can still be loaded."""
        archive = os.path.join("tests", "data", "reforestree", "reforesTree.zip")
        shutil.copy(archive, tmp_path)
        ReforesTree(root=str(tmp_path))

    def test_corrupted(self, tmp_path: Path) -> None:
        """An archive with a bad checksum raises RuntimeError."""
        (tmp_path / "reforesTree.zip").write_text("bad")
        with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
            ReforesTree(root=str(tmp_path), checksum=True)

    def test_not_found(self, tmp_path: Path) -> None:
        """An empty root without download raises RuntimeError."""
        with pytest.raises(RuntimeError, match="Dataset not found in"):
            ReforesTree(str(tmp_path))

    def test_plot(self, dataset: ReforesTree) -> None:
        """Plotting a plain sample runs without error."""
        sample = dataset[0].copy()
        dataset.plot(sample, suptitle="Test")
        plt.close()

    def test_plot_prediction(self, dataset: ReforesTree) -> None:
        """Plotting a sample that also carries predicted boxes runs without error."""
        sample = dataset[0].copy()
        sample["prediction_boxes"] = sample["boxes"].clone()
        dataset.plot(sample, suptitle="Prediction")
        plt.close()
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from .oscd import OSCD
from .patternnet import PatternNet
from .potsdam import Potsdam2D
from .reforestree import ReforesTree
from .resisc45 import RESISC45
from .seco import SeasonalContrastS2
from .sen12ms import SEN12MS
Expand Down Expand Up @@ -167,6 +168,7 @@
"PatternNet",
"Potsdam2D",
"RESISC45",
"ReforesTree",
"SeasonalContrastS2",
"SEN12MS",
"So2Sat",
Expand Down
Loading