Skip to content

Commit

Permalink
fix failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
nilsleh committed Jun 17, 2022
1 parent d3057ca commit 37e31aa
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 1 deletion.
108 changes: 108 additions & 0 deletions tests/datasets/test_reforestree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datasets import ReforesTree


def download_url(url: str, root: str, *args: str) -> None:
shutil.copy(url, root)


class TestReforesTree:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ReforesTree:
monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url)
data_dir = os.path.join("tests", "data", "reforestree")

url = os.path.join(data_dir, "data.zip")

md5 = "4f70885489000ed52de3223514179f63"

monkeypatch.setattr(ReforesTree, "url", url)
monkeypatch.setattr(ReforesTree, "md5", md5)
root = str(tmp_path)
transforms = nn.Identity() # type: ignore[no-untyped-call]
return ReforesTree(
root=root, transforms=transforms, download=True, checksum=True
)

def test_already_downloaded(self, dataset: ReforesTree) -> None:
ReforesTree(root=dataset.root, download=True)

def test_getitem(self, dataset: ReforesTree) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["label"], torch.Tensor)
assert isinstance(x["boxes"], torch.Tensor)
assert isinstance(x["agb"], torch.Tensor)
assert x["image"].shape[0] == 3
assert x["image"].ndim == 3
assert len(x["boxes"]) == 2

@pytest.fixture(params=["pandas"])
def mock_missing_module(self, monkeypatch: MonkeyPatch, request: SubRequest) -> str:
import_orig = builtins.__import__
package = str(request.param)

def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == package:
raise ImportError()
return import_orig(name, *args, **kwargs)

monkeypatch.setattr(builtins, "__import__", mocked_import)
return package

def test_mock_missing_module(
self, dataset: ReforesTree, mock_missing_module: str
) -> None:
package = mock_missing_module

with pytest.raises(
ImportError,
match=f"{package} is not installed and is required to use this dataset",
):
ReforesTree(root=dataset.root)

def test_len(self, dataset: ReforesTree) -> None:
assert len(dataset) == 2

def test_not_extracted(self, tmp_path: Path) -> None:
url = os.path.join("tests", "data", "reforestree", "data.zip")
shutil.copy(url, tmp_path)
ReforesTree(root=str(tmp_path))

def test_corrupted(self, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, "data.zip"), "w") as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
ReforesTree(root=str(tmp_path), checksum=True)

def test_not_found(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in."):
ReforesTree(str(tmp_path))

def test_plot(self, dataset: ReforesTree) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle="Test")
plt.close()

def test_plot_prediction(self, dataset: ReforesTree) -> None:
x = dataset[0].copy()
x["prediction_boxes"] = x["boxes"].clone()
dataset.plot(x, suptitle="Prediction")
plt.close()
2 changes: 1 addition & 1 deletion torchgeo/datasets/reforestree.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class ReforesTree(VisionDataset):
"""ReforesTree dataset.
The `ReforesTree https://github.com/gyrrei/ReforesTree`_
The `ReforesTree <https://github.com/gyrrei/ReforesTree>`_
dataset contains drone imagery that can be used for tree crown detection,
tree species classification and Aboveground Biomass (AGB) estimation.
Expand Down

0 comments on commit 37e31aa

Please sign in to comment.